diff --git a/.github/workflows/ci-build-test-reusable.yml b/.github/workflows/ci-build-test-reusable.yml index 40c86e43b..91c76a363 100644 --- a/.github/workflows/ci-build-test-reusable.yml +++ b/.github/workflows/ci-build-test-reusable.yml @@ -9,9 +9,6 @@ on: version_toolchain: type: string required: true - taskdb: - type: string - required: true env: CI: 1 @@ -34,11 +31,6 @@ jobs: with: submodules: recursive - - name: taskdb check - if: ${{ inputs.taskdb != '' }} - run: | - echo "TASKDB=${{ inputs.taskdb }}" >> $GITHUB_ENV - - uses: actions-rs/toolchain@v1 with: toolchain: ${{ inputs.version_toolchain }} diff --git a/.github/workflows/ci-integration-reusable.yml b/.github/workflows/ci-integration-reusable.yml index a330dbc30..ef91ff095 100644 --- a/.github/workflows/ci-integration-reusable.yml +++ b/.github/workflows/ci-integration-reusable.yml @@ -9,9 +9,6 @@ on: version_toolchain: type: string required: true - taskdb: - type: string - required: true env: @@ -36,11 +33,6 @@ jobs: with: submodules: recursive - - name: taskdb check - if: ${{ inputs.taskdb != '' }} - run: | - echo "TASKDB=${{ inputs.taskdb }}" >> $GITHUB_ENV - - uses: actions-rs/toolchain@v1 with: toolchain: ${{ inputs.version_toolchain }} diff --git a/.github/workflows/ci-native.yml b/.github/workflows/ci-native.yml index a63cc3256..de09ac6fa 100644 --- a/.github/workflows/ci-native.yml +++ b/.github/workflows/ci-native.yml @@ -3,46 +3,18 @@ name: CI - Native on: workflow_call: pull_request: - paths: - - "taskdb/**" jobs: - set-taskdb: - runs-on: ubuntu-latest - outputs: - taskdb: ${{ steps.check_file.outputs.taskdb }} - steps: - - name: Checkout code - uses: actions/checkout@v3 - with: - fetch-depth: 0 - - - name: Check if specific file changed - id: check_file - run: | - BASE_BRANCH=${{ github.event.pull_request.base.ref }} - if git diff --name-only origin/$BASE_BRANCH ${{ github.sha }} | grep -q "taskdb/src/redis_db.rs"; then - echo "redis changed" - echo "::set-output name=taskdb::raiko-tasks/redis-db" - else - echo "redis unchanged" - echo "::set-output name=taskdb::" - fi - build-test-native: name: Build and test native - needs: set-taskdb uses: ./.github/workflows/ci-build-test-reusable.yml with: version_name: "native" version_toolchain: "nightly-2024-04-17" - taskdb: ${{ needs.set-taskdb.outputs.taskdb }} integration-test-native: name: Run integration tests on native - needs: set-taskdb uses: ./.github/workflows/ci-integration-reusable.yml with: version_name: "native" version_toolchain: "nightly-2024-04-17" - taskdb: ${{ needs.set-taskdb.outputs.taskdb }} diff --git a/Dockerfile b/Dockerfile index 23ddd3032..c2ee57135 100644 --- a/Dockerfile +++ b/Dockerfile @@ -2,7 +2,6 @@ FROM rust:1.81.0 AS builder ENV DEBIAN_FRONTEND=noninteractive ARG BUILD_FLAGS="" -ARG TASKDB=${TASKDB:-raiko-tasks/in-memory} # risc0 dependencies # RUN curl -L --proto '=https' --tlsv1.2 -sSf https://raw.githubusercontent.com/cargo-bins/cargo-binstall/main/install-from-binstall-release.sh | bash && \ @@ -11,8 +10,7 @@ ARG TASKDB=${TASKDB:-raiko-tasks/in-memory} WORKDIR /opt/raiko COPY . . 
-RUN echo "Building for sgx with taskdb: ${TASKDB}" -RUN cargo build --release ${BUILD_FLAGS} --features "sgx" --features "docker_build" --features ${TASKDB} +RUN cargo build --release ${BUILD_FLAGS} --features "sgx" --features "docker_build" FROM gramineproject/gramine:1.8-jammy AS runtime ENV DEBIAN_FRONTEND=noninteractive diff --git a/Dockerfile.zk b/Dockerfile.zk index 8c5ff6a88..1ccaa2d4e 100644 --- a/Dockerfile.zk +++ b/Dockerfile.zk @@ -32,7 +32,6 @@ FROM base-builder AS builder ENV DEBIAN_FRONTEND=noninteractive ARG BUILD_FLAGS="" -ARG TASKDB=${TASKDB:-raiko-tasks/in-memory} WORKDIR /opt/raiko # build related files @@ -63,7 +62,7 @@ RUN make guest RUN echo "Building for sp1" ENV TARGET=sp1 RUN make guest -RUN cargo build --release ${BUILD_FLAGS} --features "sp1,risc0" --features "docker_build" --features ${TASKDB} +RUN cargo build --release ${BUILD_FLAGS} --features "sp1,risc0" --features "docker_build" FROM ubuntu:22.04 AS raiko-zk RUN mkdir -p \ diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index a4009cb45..c724eafe6 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -25,7 +25,6 @@ services: context: .. args: ENABLE_SELF_REGISTER: "true" - TASKDB: ${TASKDB:-raiko-tasks/in-memory} dockerfile: Dockerfile image: us-docker.pkg.dev/evmchain/images/raiko:latest container_name: raiko-init-self-register @@ -54,8 +53,6 @@ services: raiko: build: context: .. - args: - TASKDB: ${TASKDB:-raiko-tasks/in-memory} dockerfile: Dockerfile image: us-docker.pkg.dev/evmchain/images/raiko:latest container_name: raiko @@ -98,8 +95,6 @@ services: build: context: .. dockerfile: Dockerfile - args: - TASKDB: ${TASKDB:-raiko-tasks/in-memory} image: us-docker.pkg.dev/evmchain/images/raiko:latest container_name: raiko-self-register command: --config-path=/etc/raiko/config.sgx.json --chain-spec-path=/etc/raiko/chain_spec_list.docker.json @@ -136,8 +131,6 @@ services: build: context: .. dockerfile: Dockerfile.zk - args: - TASKDB: ${TASKDB:-raiko-tasks/in-memory} image: us-docker.pkg.dev/evmchain/images/raiko-zk:latest container_name: raiko-zk command: --config-path=/etc/raiko/config.sgx.json --chain-spec-path=/etc/raiko/chain_spec_list.docker.json diff --git a/host/src/interfaces.rs b/host/src/interfaces.rs index 301f370c4..7c008f3ff 100644 --- a/host/src/interfaces.rs +++ b/host/src/interfaces.rs @@ -1,7 +1,7 @@ use axum::response::IntoResponse; use raiko_lib::proof_type::ProofType; use raiko_lib::prover::ProverError; -use raiko_tasks::{TaskManagerError, TaskStatus}; +use raiko_tasks::TaskStatus; use tokio::sync::mpsc::error::TrySendError; use utoipa::ToSchema; @@ -69,10 +69,6 @@ pub enum HostError { #[schema(value_type = Value)] Anyhow(#[from] anyhow::Error), - /// For task manager errors. - #[error("There was an error with the task manager: {0}")] - TaskManager(#[from] TaskManagerError), - /// For system paused state. 
     #[error("System is paused")]
     SystemPaused,
@@ -91,7 +87,6 @@ impl IntoResponse for HostError {
             HostError::Guest(e) => ("guest_error", e.to_string()),
             HostError::Core(e) => ("core_error", e.to_string()),
             HostError::FeatureNotSupportedError(e) => ("feature_not_supported", e.to_string()),
-            HostError::TaskManager(e) => ("task_manager", e.to_string()),
             HostError::Anyhow(e) => ("anyhow_error", e.to_string()),
             HostError::HandleDropped => ("handle_dropped", "".to_owned()),
             HostError::CapacityFull => ("capacity_full", "".to_owned()),
@@ -134,7 +129,6 @@ impl From<HostError> for TaskStatus {
             HostError::Io(e) => TaskStatus::IoFailure(e.to_string()),
             HostError::RPC(e) => TaskStatus::NetworkFailure(e.to_string()),
             HostError::Guest(e) => TaskStatus::GuestProverFailure(e.to_string()),
-            HostError::TaskManager(e) => TaskStatus::TaskDbCorruption(e.to_string()),
             HostError::SystemPaused => TaskStatus::SystemPaused,
         }
     }
@@ -156,7 +150,6 @@ impl From<&HostError> for TaskStatus {
             HostError::Io(e) => TaskStatus::GuestProverFailure(e.to_string()),
             HostError::RPC(e) => TaskStatus::NetworkFailure(e.to_string()),
             HostError::Guest(e) => TaskStatus::GuestProverFailure(e.to_string()),
-            HostError::TaskManager(e) => TaskStatus::TaskDbCorruption(e.to_string()),
             HostError::SystemPaused => TaskStatus::SystemPaused,
         }
     }
diff --git a/host/src/lib.rs b/host/src/lib.rs
index f7f9408e1..404d81d98 100644
--- a/host/src/lib.rs
+++ b/host/src/lib.rs
@@ -1,26 +1,18 @@
-use std::sync::atomic::{AtomicBool, Ordering};
-use std::sync::Arc;
 use std::{alloc, path::PathBuf};
 
 use anyhow::Context;
 use cap::Cap;
 use clap::Parser;
-use raiko_core::{
-    interfaces::{AggregationOnlyRequest, ProofRequest, ProofRequestOpt},
-    merge,
-};
+use raiko_core::{interfaces::ProofRequestOpt, merge};
 use raiko_lib::consts::SupportedChainSpecs;
-use raiko_tasks::{get_task_manager, ProofTaskDescriptor, TaskManagerOpts, TaskManagerWrapperImpl};
 use serde::{Deserialize, Serialize};
 use serde_json::Value;
-use tokio::sync::mpsc;
 
-use crate::{interfaces::HostResult, proof::ProofActor};
+use crate::interfaces::HostResult;
 
 pub mod cache;
 pub mod interfaces;
 pub mod metrics;
-pub mod proof;
 pub mod server;
 
 #[derive(Default, Clone, Serialize, Deserialize, Debug, Parser)]
@@ -131,112 +123,6 @@ impl Opts {
     }
 }
 
-impl From<Opts> for TaskManagerOpts {
-    fn from(val: Opts) -> Self {
-        Self {
-            max_db_size: val.max_db_size,
-            redis_url: val.redis_url.to_string(),
-            redis_ttl: val.redis_ttl,
-        }
-    }
-}
-
-impl From<&Opts> for TaskManagerOpts {
-    fn from(val: &Opts) -> Self {
-        Self {
-            max_db_size: val.max_db_size,
-            redis_url: val.redis_url.to_string(),
-            redis_ttl: val.redis_ttl,
-        }
-    }
-}
-
-#[derive(Debug, Clone)]
-pub struct ProverState {
-    pub opts: Opts,
-    pub chain_specs: SupportedChainSpecs,
-    pub task_channel: mpsc::Sender<Message>,
-    pause_flag: Arc<AtomicBool>,
-}
-
-#[derive(Debug)]
-pub enum Message {
-    Cancel(ProofTaskDescriptor),
-    Task(ProofRequest),
-    TaskComplete(ProofRequest),
-    CancelAggregate(AggregationOnlyRequest),
-    Aggregate(AggregationOnlyRequest),
-    SystemPause(tokio::sync::oneshot::Sender<HostResult<()>>),
-}
-
-impl ProverState {
-    pub fn init() -> HostResult<Self> {
-        let opts = parse_opts()?;
-        Self::init_with_opts(opts)
-    }
-
-    pub fn init_with_opts(opts: Opts) -> HostResult<Self> {
-        // Check if the cache path exists and create it if it doesn't.
-        if let Some(cache_path) = &opts.cache_path {
-            if !cache_path.exists() {
-                std::fs::create_dir_all(cache_path).context("Could not create cache dir")?;
-            }
-        }
-
-        let (task_channel, receiver) = mpsc::channel::<Message>(opts.concurrency_limit);
-        let pause_flag = Arc::new(AtomicBool::new(false));
-
-        let opts_clone = opts.clone();
-        let chain_specs = parse_chain_specs(&opts);
-        let chain_specs_clone = chain_specs.clone();
-        let sender = task_channel.clone();
-        tokio::spawn(async move {
-            ProofActor::new(sender, receiver, opts_clone, chain_specs_clone)
-                .run()
-                .await;
-        });
-
-        Ok(Self {
-            opts,
-            chain_specs,
-            task_channel,
-            pause_flag,
-        })
-    }
-
-    pub fn task_manager(&self) -> TaskManagerWrapperImpl {
-        get_task_manager(&(&self.opts).into())
-    }
-
-    pub fn request_config(&self) -> ProofRequestOpt {
-        self.opts.proof_request_opt.clone()
-    }
-
-    pub fn is_paused(&self) -> bool {
-        self.pause_flag.load(Ordering::SeqCst)
-    }
-
-    /// Set the pause flag and notify the task manager to pause, then wait for the task manager to
-    /// finish the pause process.
-    ///
-    /// Note that this function is blocking until the task manager finishes the pause process.
-    pub async fn set_pause(&self, paused: bool) -> HostResult<()> {
-        self.pause_flag.store(paused, Ordering::SeqCst);
-        if paused {
-            // Notify task manager to start pause process
-            let (sender, receiver) = tokio::sync::oneshot::channel();
-            self.task_channel
-                .try_send(Message::SystemPause(sender))
-                .context("Failed to send pause message")?;
-
-            // Wait for the pause message to be processed
-            let result = receiver.await.context("Failed to receive pause message")?;
-            return result;
-        }
-        Ok(())
-    }
-}
-
 pub fn parse_opts() -> HostResult<Opts> {
     // Read the command line arguments;
     let mut opts = Opts::parse();
@@ -259,6 +145,7 @@ pub fn parse_chain_specs(opts: &Opts) -> SupportedChainSpecs {
 #[global_allocator]
 static ALLOCATOR: Cap<alloc::System> = Cap::new(alloc::System, usize::MAX);
 
+#[allow(unused)]
 mod memory {
     use tracing::debug;
 
diff --git a/host/src/proof.rs b/host/src/proof.rs
deleted file mode 100644
index 7453fc4b8..000000000
--- a/host/src/proof.rs
+++ /dev/null
@@ -1,573 +0,0 @@
-use std::{
-    collections::{HashMap, VecDeque},
-    str::FromStr,
-    sync::Arc,
-};
-
-use raiko_core::{
-    interfaces::{
-        aggregate_proofs, cancel_proof, AggregationOnlyRequest, ProofRequest, RaikoError,
-    },
-    provider::{get_task_data, rpc::RpcBlockDataProvider},
-    Raiko,
-};
-use raiko_lib::{
-    consts::SupportedChainSpecs,
-    input::{AggregationGuestInput, AggregationGuestOutput},
-    proof_type::ProofType,
-    prover::{IdWrite, Proof},
-    Measurement,
-};
-use raiko_tasks::{
-    get_task_manager, ProofTaskDescriptor, TaskManager, TaskManagerWrapperImpl, TaskStatus,
-};
-use reth_primitives::B256;
-use tokio::{
-    select,
-    sync::{
-        mpsc::{Receiver, Sender},
-        Mutex, OwnedSemaphorePermit, Semaphore,
-    },
-};
-use tokio_util::sync::CancellationToken;
-use tracing::{debug, error, info, warn};
-
-use crate::{
-    cache,
-    interfaces::{HostError, HostResult},
-    memory,
-    metrics::{
-        inc_guest_error, inc_guest_success, inc_host_error, observe_guest_time,
-        observe_prepare_input_time, observe_total_time,
-    },
-    Message, Opts,
-};
-
-pub struct ProofActor {
-    opts: Opts,
-    chain_specs: SupportedChainSpecs,
-    aggregate_tasks: Arc<Mutex<HashMap<AggregationOnlyRequest, CancellationToken>>>,
-    running_tasks: Arc<Mutex<HashMap<ProofTaskDescriptor, CancellationToken>>>,
-    pending_tasks: Arc<Mutex<VecDeque<ProofRequest>>>,
-    receiver: Receiver<Message>,
-    sender: Sender<Message>,
-}
-
-impl ProofActor {
-    pub fn new(
-        sender: Sender<Message>,
-        receiver: Receiver<Message>,
-        opts: Opts,
-        chain_specs: SupportedChainSpecs,
-    ) -> Self {
-        let running_tasks =
Arc::new(Mutex::new(
-            HashMap::<ProofTaskDescriptor, CancellationToken>::new(),
-        ));
-        let aggregate_tasks = Arc::new(Mutex::new(HashMap::<
-            AggregationOnlyRequest,
-            CancellationToken,
-        >::new()));
-        let pending_tasks = Arc::new(Mutex::new(VecDeque::<ProofRequest>::new()));
-
-        Self {
-            opts,
-            chain_specs,
-            aggregate_tasks,
-            running_tasks,
-            pending_tasks,
-            receiver,
-            sender,
-        }
-    }
-
-    pub async fn cancel_task(&mut self, key: ProofTaskDescriptor) -> HostResult<()> {
-        let task = {
-            let tasks_map = self.running_tasks.lock().await;
-            match tasks_map.get(&key) {
-                Some(task) => task.to_owned(),
-                None => {
-                    warn!("No task with those keys to cancel");
-                    return Ok(());
-                }
-            }
-        };
-
-        let mut manager = get_task_manager(&self.opts.clone().into());
-        cancel_proof(
-            key.proof_system,
-            (
-                key.chain_id,
-                key.block_id,
-                key.blockhash,
-                key.proof_system as u8,
-            ),
-            Box::new(&mut manager),
-        )
-        .await
-        .or_else(|e| {
-            if e.to_string().contains("No data for query") {
-                warn!("Task already cancelled or not yet started!");
-                Ok(())
-            } else {
-                Err::<(), HostError>(e.into())
-            }
-        })?;
-        task.cancel();
-        Ok(())
-    }
-
-    pub async fn run_task(&mut self, proof_request: ProofRequest) {
-        let cancel_token = CancellationToken::new();
-
-        let (chain_id, blockhash) = match get_task_data(
-            &proof_request.network,
-            proof_request.block_number,
-            &self.chain_specs,
-        )
-        .await
-        {
-            Ok(v) => v,
-            Err(e) => {
-                error!("Could not get task data for {proof_request:?}, error: {e}");
-                return;
-            }
-        };
-
-        let key = ProofTaskDescriptor::from((
-            chain_id,
-            proof_request.block_number,
-            blockhash,
-            proof_request.proof_type,
-            proof_request.prover.clone().to_string(),
-        ));
-
-        {
-            let mut tasks = self.running_tasks.lock().await;
-            tasks.insert(key.clone(), cancel_token.clone());
-        }
-
-        let sender = self.sender.clone();
-        let tasks = self.running_tasks.clone();
-        let opts = self.opts.clone();
-        let chain_specs = self.chain_specs.clone();
-
-        tokio::spawn(async move {
-            select!
{ - _ = cancel_token.cancelled() => { - info!("Task cancelled"); - } - result = Self::handle_message(proof_request.clone(), key.clone(), &opts, &chain_specs) => { - match result { - Ok(status) => { - info!("Host handling message: {status:?}"); - } - Err(error) => { - error!("Worker failed due to: {error:?}"); - } - }; - } - } - let mut tasks = tasks.lock().await; - tasks.remove(&key); - // notify complete task to let next pending task run - sender - .send(Message::TaskComplete(proof_request)) - .await - .expect("Couldn't send message"); - }); - } - - pub async fn cancel_aggregation_task( - &mut self, - request: AggregationOnlyRequest, - ) -> HostResult<()> { - let tasks_map = self.aggregate_tasks.lock().await; - let Some(task) = tasks_map.get(&request) else { - warn!("No task with those keys to cancel"); - return Ok(()); - }; - - // TODO:(petar) implement cancel_proof_aggregation - // let mut manager = get_task_manager(&self.opts.clone().into()); - // let proof_type = ProofType::from_str( - // request - // .proof_type - // .as_ref() - // .ok_or_else(|| anyhow!("No proof type"))?, - // )?; - // proof_type - // .cancel_proof_aggregation(request, Box::new(&mut manager)) - // .await - // .or_else(|e| { - // if e.to_string().contains("No data for query") { - // warn!("Task already cancelled or not yet started!"); - // Ok(()) - // } else { - // Err::<(), HostError>(e.into()) - // } - // })?; - task.cancel(); - Ok(()) - } - - pub async fn run_aggregate( - &mut self, - request: AggregationOnlyRequest, - _permit: OwnedSemaphorePermit, - ) { - let cancel_token = CancellationToken::new(); - - let mut tasks = self.aggregate_tasks.lock().await; - tasks.insert(request.clone(), cancel_token.clone()); - - let request_clone = request.clone(); - let tasks = self.aggregate_tasks.clone(); - let opts = self.opts.clone(); - - tokio::spawn(async move { - select! 
{
-                _ = cancel_token.cancelled() => {
-                    info!("Task cancelled");
-                }
-                result = Self::handle_aggregate(request_clone, &opts) => {
-                    match result {
-                        Ok(status) => {
-                            info!("Host handling message: {status:?}");
-                        }
-                        Err(error) => {
-                            error!("Worker failed due to: {error:?}");
-                        }
-                    };
-                }
-            }
-            let mut tasks = tasks.lock().await;
-            tasks.remove(&request);
-        });
-    }
-
-    pub async fn run(&mut self) {
-        // recv() is protected by outside mpsc, no lock needed here
-        let semaphore = Arc::new(Semaphore::new(self.opts.concurrency_limit));
-        while let Some(message) = self.receiver.recv().await {
-            match message {
-                Message::Cancel(key) => {
-                    debug!("Message::Cancel({key:?})");
-                    if let Err(error) = self.cancel_task(key).await {
-                        error!("Failed to cancel task: {error}")
-                    }
-                }
-                Message::Task(proof_request) => {
-                    debug!("Message::Task({proof_request:?})");
-                    let running_task_count = self.running_tasks.lock().await.len();
-                    if running_task_count < self.opts.concurrency_limit {
-                        info!("Running task {proof_request:?}");
-                        self.run_task(proof_request).await;
-                    } else {
-                        info!(
-                            "Task concurrency status: running:{running_task_count:?}, add {proof_request:?} to pending list[{:?}]",
-                            self.pending_tasks.lock().await.len()
-                        );
-                        let mut pending_tasks = self.pending_tasks.lock().await;
-                        pending_tasks.push_back(proof_request);
-                    }
-                }
-                Message::TaskComplete(req) => {
-                    // pop up pending task if any task complete
-                    debug!("Message::TaskComplete({req:?})");
-                    info!(
-                        "task {req:?} completed, current running {:?}, pending: {:?}",
-                        self.running_tasks.lock().await.len(),
-                        self.pending_tasks.lock().await.len()
-                    );
-                    let mut pending_tasks = self.pending_tasks.lock().await;
-                    if let Some(proof_request) = pending_tasks.pop_front() {
-                        info!("Pop out pending task {proof_request:?}");
-                        self.sender
-                            .send(Message::Task(proof_request))
-                            .await
-                            .expect("Couldn't send message");
-                    }
-                }
-                Message::CancelAggregate(request) => {
-                    debug!("Message::CancelAggregate({request:?})");
-                    if let Err(error) = self.cancel_aggregation_task(request).await {
-                        error!("Failed to cancel task: {error}")
-                    }
-                }
-                Message::Aggregate(request) => {
-                    debug!("Message::Aggregate({request:?})");
-                    let permit = Arc::clone(&semaphore)
-                        .acquire_owned()
-                        .await
-                        .expect("Couldn't acquire permit");
-                    self.run_aggregate(request, permit).await;
-                }
-                Message::SystemPause(notifier) => {
-                    let result = self.handle_system_pause().await;
-                    let _ = notifier.send(result);
-                }
-            }
-        }
-    }
-
-    pub async fn handle_message(
-        proof_request: ProofRequest,
-        key: ProofTaskDescriptor,
-        opts: &Opts,
-        chain_specs: &SupportedChainSpecs,
-    ) -> HostResult<TaskStatus> {
-        let mut manager = get_task_manager(&opts.clone().into());
-
-        let status = manager.get_task_proving_status(&key).await?;
-
-        if let Some(latest_status) = status.0.iter().last() {
-            if !matches!(latest_status.0, TaskStatus::Registered) {
-                return Ok(latest_status.0.clone());
-            }
-        }
-
-        manager
-            .update_task_progress(key.clone(), TaskStatus::WorkInProgress, None)
-            .await?;
-
-        let (status, proof) =
-            match handle_proof(&proof_request, opts, chain_specs, Some(&mut manager)).await {
-                Err(error) => {
-                    error!("{error}");
-                    (error.into(), None)
-                }
-                Ok(proof) => (TaskStatus::Success, Some(serde_json::to_vec(&proof)?)),
-            };
-
-        manager
-            .update_task_progress(key, status.clone(), proof.as_deref())
-            .await
-            .map_err(HostError::from)?;
-        Ok(status)
-    }
-
-    pub async fn handle_aggregate(request: AggregationOnlyRequest, opts: &Opts) -> HostResult<()> {
-        let proof_type_str =
request.proof_type.to_owned().unwrap_or_default();
-        let proof_type = ProofType::from_str(&proof_type_str).map_err(HostError::Conversion)?;
-
-        let mut manager = get_task_manager(&opts.clone().into());
-
-        let status = manager
-            .get_aggregation_task_proving_status(&request)
-            .await?;
-
-        if let Some(latest_status) = status.0.iter().last() {
-            if !matches!(latest_status.0, TaskStatus::Registered) {
-                return Ok(());
-            }
-        }
-
-        manager
-            .update_aggregation_task_progress(&request, TaskStatus::WorkInProgress, None)
-            .await?;
-
-        let input = AggregationGuestInput {
-            proofs: request.clone().proofs,
-        };
-        let output = AggregationGuestOutput { hash: B256::ZERO };
-        let config = serde_json::to_value(request.clone().prover_args)?;
-        let mut manager = get_task_manager(&opts.clone().into());
-
-        let (status, proof) =
-            match aggregate_proofs(proof_type, input, &output, &config, Some(&mut manager)).await {
-                Err(error) => {
-                    error!("{error}");
-                    (HostError::from(error).into(), None)
-                }
-                Ok(proof) => (TaskStatus::Success, Some(serde_json::to_vec(&proof)?)),
-            };
-
-        manager
-            .update_aggregation_task_progress(&request, status, proof.as_deref())
-            .await?;
-
-        Ok(())
-    }
-
-    async fn cancel_all_running_tasks(&mut self) -> HostResult<()> {
-        info!("Cancelling all running tasks");
-
-        // Clone all tasks to avoid holding locks and risking deadlock; they will be locked by
-        // other internal functions.
-        let running_tasks = {
-            let running_tasks = self.running_tasks.lock().await;
-            (*running_tasks).clone()
-        };
-
-        // Cancel all running tasks, don't stop even if any task fails.
-        let mut final_result = Ok(());
-        for proof_task_descriptor in running_tasks.keys() {
-            match self.cancel_task(proof_task_descriptor.clone()).await {
-                Ok(()) => {
-                    info!(
-                        "Cancel task during system pause, task: {:?}",
-                        proof_task_descriptor
-                    );
-                }
-                Err(e) => {
-                    error!(
-                        "Failed to cancel task during system pause: {}, task: {:?}",
-                        e, proof_task_descriptor
-                    );
-                    final_result = final_result.and(Err(e));
-                }
-            }
-        }
-        final_result
-    }
-
-    async fn cancel_all_aggregation_tasks(&mut self) -> HostResult<()> {
-        info!("Cancelling all aggregation tasks");
-
-        // Clone all tasks to avoid holding locks and risking deadlock; they will be locked by
-        // other internal functions.
-        let aggregate_tasks = {
-            let aggregate_tasks = self.aggregate_tasks.lock().await;
-            (*aggregate_tasks).clone()
-        };
-
-        // Cancel all aggregation tasks, don't stop even if any task fails.
-        let mut final_result = Ok(());
-        for request in aggregate_tasks.keys() {
-            match self.cancel_aggregation_task(request.clone()).await {
-                Ok(()) => {
-                    info!(
-                        "Cancel aggregation task during system pause, task: {}",
-                        request
-                    );
-                }
-                Err(e) => {
-                    error!(
-                        "Failed to cancel aggregation task during system pause: {}, task: {}",
-                        e, request
-                    );
-                    final_result = final_result.and(Err(e));
-                }
-            }
-        }
-        final_result
-    }
-
-    async fn handle_system_pause(&mut self) -> HostResult<()> {
-        info!("System pausing");
-
-        let mut final_result = Ok(());
-
-        self.pending_tasks.lock().await.clear();
-
-        if let Err(e) = self.cancel_all_running_tasks().await {
-            final_result = final_result.and(Err(e));
-        }
-
-        if let Err(e) = self.cancel_all_aggregation_tasks().await {
-            final_result = final_result.and(Err(e));
-        }
-
-        // TODO(Kero): make sure all tasks are saved to database, including pending tasks.
-
-        final_result
-    }
-}
-
-pub async fn handle_proof(
-    proof_request: &ProofRequest,
-    opts: &Opts,
-    chain_specs: &SupportedChainSpecs,
-    store: Option<&mut TaskManagerWrapperImpl>,
-) -> HostResult<Proof> {
-    info!(
-        "Generating proof for block {} on {}",
-        proof_request.block_number, proof_request.network
-    );
-
-    // Check for a cached input for the given request config.
-    let cached_input = cache::get_input(
-        &opts.cache_path,
-        proof_request.block_number,
-        &proof_request.network.to_string(),
-    );
-
-    let l1_chain_spec = chain_specs
-        .get_chain_spec(&proof_request.l1_network.to_string())
-        .ok_or_else(|| HostError::InvalidRequestConfig("Unsupported l1 network".to_string()))?;
-
-    let taiko_chain_spec = chain_specs
-        .get_chain_spec(&proof_request.network.to_string())
-        .ok_or_else(|| HostError::InvalidRequestConfig("Unsupported raiko network".to_string()))?;
-
-    // Execute the proof generation.
-    let total_time = Measurement::start("", false);
-
-    let raiko = Raiko::new(
-        l1_chain_spec.clone(),
-        taiko_chain_spec.clone(),
-        proof_request.clone(),
-    );
-    let provider = RpcBlockDataProvider::new(
-        &taiko_chain_spec.rpc.clone(),
-        proof_request.block_number - 1,
-    )?;
-    let input = match cache::validate_input(cached_input, &provider).await {
-        Ok(cache_input) => cache_input,
-        Err(_) => {
-            // no valid cache
-            memory::reset_stats();
-            let measurement = Measurement::start("Generating input...", false);
-            let input = raiko.generate_input(provider).await?;
-            let input_time = measurement.stop_with("=> Input generated");
-            observe_prepare_input_time(proof_request.block_number, input_time, true);
-            memory::print_stats("Input generation peak memory used: ");
-            input
-        }
-    };
-    memory::reset_stats();
-    let output = raiko.get_output(&input)?;
-    memory::print_stats("Guest program peak memory used: ");
-
-    memory::reset_stats();
-    let measurement = Measurement::start("Generating proof...", false);
-    let proof = raiko
-        .prove(input.clone(), &output, store.map(|s| s as &mut dyn IdWrite))
-        .await
-        .map_err(|e| {
-            let total_time = total_time.stop_with("====> Proof generation failed");
-            observe_total_time(proof_request.block_number, total_time, false);
-            match e {
-                RaikoError::Guest(e) => {
-                    inc_guest_error(&proof_request.proof_type, proof_request.block_number);
-                    HostError::Core(e.into())
-                }
-                e => {
-                    inc_host_error(proof_request.block_number);
-                    e.into()
-                }
-            }
-        })?;
-    let guest_time = measurement.stop_with("=> Proof generated");
-    observe_guest_time(
-        &proof_request.proof_type,
-        proof_request.block_number,
-        guest_time,
-        true,
-    );
-    memory::print_stats("Prover peak memory used: ");
-
-    inc_guest_success(&proof_request.proof_type, proof_request.block_number);
-    let total_time = total_time.stop_with("====> Complete proof generated");
-    observe_total_time(proof_request.block_number, total_time, true);
-
-    // Cache the input for future use.
- cache::set_input( - &opts.cache_path, - proof_request.block_number, - &proof_request.network.to_string(), - &input, - )?; - - Ok(proof) -} diff --git a/host/tests/common/request.rs b/host/tests/common/request.rs index d807587a1..8617bfe4d 100644 --- a/host/tests/common/request.rs +++ b/host/tests/common/request.rs @@ -3,7 +3,7 @@ use raiko_host::server::api; use raiko_lib::consts::Network; use raiko_lib::proof_type::ProofType; use raiko_lib::prover::Proof; -use raiko_tasks::{TaskDescriptor, TaskReport, TaskStatus}; +use raiko_tasks::{AggregationTaskDescriptor, TaskDescriptor, TaskReport, TaskStatus}; use serde_json::json; use crate::common::Client; @@ -246,7 +246,11 @@ pub async fn get_status_of_aggregation_proof_request( client: &Client, request: &AggregationOnlyRequest, ) -> TaskStatus { - let expected_task_descriptor: TaskDescriptor = TaskDescriptor::Aggregation(request.into()); + let descriptor = AggregationTaskDescriptor { + aggregation_ids: request.aggregation_ids.clone(), + proof_type: request.proof_type.clone().map(|p| p.to_string()), + }; + let expected_task_descriptor: TaskDescriptor = TaskDescriptor::Aggregation(descriptor); let report = v2_assert_report(client).await; for (task_descriptor, task_status) in &report { if task_descriptor == &expected_task_descriptor { diff --git a/reqactor/src/actor.rs b/reqactor/src/actor.rs index 9085a4608..74e79d195 100644 --- a/reqactor/src/actor.rs +++ b/reqactor/src/actor.rs @@ -116,8 +116,8 @@ mod tests { proof_type::ProofType, }; use raiko_reqpool::{ - memory_pool, Pool, RequestEntity, RequestKey, SingleProofRequestEntity, - SingleProofRequestKey, StatusWithContext, + memory_pool, RequestEntity, RequestKey, SingleProofRequestEntity, SingleProofRequestKey, + StatusWithContext, }; use std::collections::HashMap; use tokio::sync::mpsc; diff --git a/script/build.sh b/script/build.sh index a14a3d70c..3ea21d9d0 100755 --- a/script/build.sh +++ b/script/build.sh @@ -33,8 +33,6 @@ else echo "Warning: in debug mode" fi -TASKDB=${TASKDB:-raiko-tasks/in-memory} - if [ -z "${RUN}" ]; then COMMAND=build else @@ -49,22 +47,22 @@ fi # NATIVE if [ -z "$1" ] || [ "$1" == "native" ]; then if [ -n "${CLIPPY}" ]; then - cargo clippy -F ${TASKDB} -- -D warnings + cargo clippy -- -D warnings elif [ -z "${RUN}" ]; then if [ -z "${TEST}" ]; then echo "Building native prover" - cargo build ${FLAGS} -F $TASKDB + cargo build ${FLAGS} else echo "Building native tests" - cargo test ${FLAGS} --no-run -F $TASKDB + cargo test ${FLAGS} --no-run fi else if [ -z "${TEST}" ]; then echo "Running native prover" - cargo run ${FLAGS} -F $TASKDB + cargo run ${FLAGS} else echo "Running native tests" - cargo test ${FLAGS} -F $TASKDB + cargo test ${FLAGS} fi fi fi @@ -77,22 +75,22 @@ if [ "$1" == "sgx" ]; then echo "SGX_DIRECT is set to $SGX_DIRECT" fi if [ -n "${CLIPPY}" ]; then - cargo ${TOOLCHAIN_SGX} clippy -p raiko-host -p sgx-prover -F "sgx enable" -F $TASKDB -- -D warnings + cargo ${TOOLCHAIN_SGX} clippy -p raiko-host -p sgx-prover -F "sgx enable" -- -D warnings elif [ -z "${RUN}" ]; then if [ -z "${TEST}" ]; then echo "Building SGX prover" - cargo ${TOOLCHAIN_SGX} build ${FLAGS} --features sgx -F $TASKDB + cargo ${TOOLCHAIN_SGX} build ${FLAGS} --features sgx else echo "Building SGX tests" - cargo ${TOOLCHAIN_SGX} test ${FLAGS} -p raiko-host -p sgx-prover --features "sgx enable" -F $TASKDB --no-run + cargo ${TOOLCHAIN_SGX} test ${FLAGS} -p raiko-host -p sgx-prover --features "sgx enable" --no-run fi else if [ -z "${TEST}" ]; then echo "Running SGX prover" - cargo 
${TOOLCHAIN_SGX} run ${FLAGS} --features sgx -F $TASKDB
+            cargo ${TOOLCHAIN_SGX} run ${FLAGS} --features sgx
         else
             echo "Running SGX tests"
-            cargo ${TOOLCHAIN_SGX} test ${FLAGS} -p raiko-host -p sgx-prover --features "sgx enable" -F $TASKDB
+            cargo ${TOOLCHAIN_SGX} test ${FLAGS} -p raiko-host -p sgx-prover --features "sgx enable"
         fi
     fi
 fi
@@ -109,27 +107,27 @@ if [ "$1" == "risc0" ]; then
         MOCK=1
         RISC0_DEV_MODE=1
         CI=1
-        cargo ${TOOLCHAIN_RISC0} run --bin risc0-builder -F $TASKDB
-        cargo ${TOOLCHAIN_RISC0} clippy -F risc0 -F $TASKDB
+        cargo ${TOOLCHAIN_RISC0} run --bin risc0-builder
+        cargo ${TOOLCHAIN_RISC0} clippy -F risc0
     elif [ -z "${RUN}" ]; then
         if [ -z "${TEST}" ]; then
            echo "Building Risc0 prover"
-            cargo ${TOOLCHAIN_RISC0} run --bin risc0-builder -F $TASKDB
+            cargo ${TOOLCHAIN_RISC0} run --bin risc0-builder
         else
            echo "Building test elfs for Risc0 prover"
-            cargo ${TOOLCHAIN_RISC0} run --bin risc0-builder --features test,bench -F $TASKDB
+            cargo ${TOOLCHAIN_RISC0} run --bin risc0-builder --features test,bench
         fi
         if [ -z "${GUEST}" ]; then
-            cargo ${TOOLCHAIN_RISC0} build ${FLAGS} --features risc0 -F $TASKDB
+            cargo ${TOOLCHAIN_RISC0} build ${FLAGS} --features risc0
         fi
     else
         if [ -z "${TEST}" ]; then
            echo "Running Risc0 prover"
-            cargo ${TOOLCHAIN_RISC0} run ${FLAGS} --features risc0 -F $TASKDB
+            cargo ${TOOLCHAIN_RISC0} run ${FLAGS} --features risc0
        else
            echo "Running Risc0 tests"
-            cargo ${TOOLCHAIN_RISC0} test ${FLAGS} --lib risc0-driver --features risc0 -F $TASKDB -- run_unittest_elf
-            cargo ${TOOLCHAIN_RISC0} test ${FLAGS} -p raiko-host -p risc0-driver --features "risc0 enable" -F $TASKDB
+            cargo ${TOOLCHAIN_RISC0} test ${FLAGS} --lib risc0-driver --features risc0 -- run_unittest_elf
+            cargo ${TOOLCHAIN_RISC0} test ${FLAGS} -p raiko-host -p risc0-driver --features "risc0 enable"
         fi
     fi
 fi
@@ -142,31 +140,31 @@ if [ "$1" == "sp1" ]; then
         echo "SP1_PROVER is set to $SP1_PROVER"
     fi
     if [ -n "${CLIPPY}" ]; then
-        cargo ${TOOLCHAIN_SP1} clippy -p raiko-host -p sp1-builder -p sp1-driver -F "sp1,enable" -F $TASKDB
+        cargo ${TOOLCHAIN_SP1} clippy -p raiko-host -p sp1-builder -p sp1-driver -F "sp1,enable"
     elif [ -z "${RUN}" ]; then
         if [ -z "${TEST}" ]; then
             echo "Building Sp1 prover"
-            cargo ${TOOLCHAIN_SP1} run --bin sp1-builder -F $TASKDB
+            cargo ${TOOLCHAIN_SP1} run --bin sp1-builder
         else
             echo "Building test elfs for Sp1 prover"
-            cargo ${TOOLCHAIN_SP1} run --bin sp1-builder --features test,bench -F $TASKDB
+            cargo ${TOOLCHAIN_SP1} run --bin sp1-builder --features test,bench
         fi
         if [ -z "${GUEST}" ]; then
-            echo "Building 'cargo ${TOOLCHAIN_SP1} build ${FLAGS} --features sp1 -F $TASKDB'"
-            cargo ${TOOLCHAIN_SP1} build ${FLAGS} --features sp1 -F $TASKDB
+            echo "Building 'cargo ${TOOLCHAIN_SP1} build ${FLAGS} --features sp1'"
+            cargo ${TOOLCHAIN_SP1} build ${FLAGS} --features sp1
         fi
     else
         if [ -z "${TEST}" ]; then
             echo "Running Sp1 prover"
-            cargo ${TOOLCHAIN_SP1} run ${FLAGS} --features sp1 -F $TASKDB
+            cargo ${TOOLCHAIN_SP1} run ${FLAGS} --features sp1
         else
             echo "Running Sp1 unit tests"
-            cargo ${TOOLCHAIN_SP1} test ${FLAGS} --lib sp1-driver --features sp1 -F $TASKDB -- run_unittest_elf
-            cargo ${TOOLCHAIN_SP1} test ${FLAGS} -p raiko-host -p sp1-driver --features "sp1 enable" -F $TASKDB
+            cargo ${TOOLCHAIN_SP1} test ${FLAGS} --lib sp1-driver --features sp1 -- run_unittest_elf
+            cargo ${TOOLCHAIN_SP1} test ${FLAGS} -p raiko-host -p sp1-driver --features "sp1 enable"
             # Don't want to spam Succinct Network and wait 2 hours in CI
             # echo "Running Sp1 verification"
-            # cargo ${TOOLCHAIN_SP1} run
${FLAGS} --bin sp1-verifier --features enable,sp1-verifier -F $TASKDB
+            # cargo ${TOOLCHAIN_SP1} run ${FLAGS} --bin sp1-verifier --features enable,sp1-verifier
         fi
     fi
 fi
diff --git a/script/generate-docs.sh b/script/generate-docs.sh
index 138ab9c25..219205ee6 100755
--- a/script/generate-docs.sh
+++ b/script/generate-docs.sh
@@ -2,9 +2,7 @@
 
 DIR="$(cd "$(dirname "$0")" && pwd)"
 
-TASKDB=${TASKDB:-raiko-tasks/in-memory}
-
 cd $DIR
 
 mkdir ../openapi
-cargo run -F ${TASKDB} --bin docs >../openapi/index.html
+cargo run --bin docs >../openapi/index.html
diff --git a/script/publish-image.sh b/script/publish-image.sh
index 12e2e706d..7cad6d2e2 100755
--- a/script/publish-image.sh
+++ b/script/publish-image.sh
@@ -13,15 +13,12 @@ if [[ -z "$tag" ]]; then
     tag="latest"
 fi
 
-TASKDB=${TASKDB:-raiko-tasks/in-memory}
-
 echo "Build and push $1:$tag..."
 docker buildx build ./ \
     --load \
     --platform linux/amd64 \
     -t raiko:$tag \
     $build_flags \
-    --build-arg TASKDB=${TASKDB} \
     --build-arg TARGETPLATFORM=linux/amd64 \
     --progress=plain
 
diff --git a/taskdb/Cargo.toml b/taskdb/Cargo.toml
index cb9883da0..f22dd831b 100644
--- a/taskdb/Cargo.toml
+++ b/taskdb/Cargo.toml
@@ -7,27 +7,5 @@ edition = "2021"
 [dependencies]
 raiko-lib = { workspace = true }
 raiko-core = { workspace = true }
-num_enum = { workspace = true }
-chrono = { workspace = true, features = ["serde"] }
-thiserror = { workspace = true }
 serde = { workspace = true }
-serde_json = { workspace = true }
-hex = { workspace = true }
-tracing = { workspace = true }
-anyhow = { workspace = true }
-tokio = { workspace = true }
-async-trait = { workspace = true }
 utoipa = { workspace = true }
-redis = { workspace = true, optional = true }
-backoff = { workspace = true }
-
-[dev-dependencies]
-rand = "0.9.0-alpha.1" # This is an alpha version that has rng.gen_iter()
-rand_chacha = "0.9.0-alpha.1"
-tempfile = "3.10.1"
-alloy-primitives = { workspace = true, features = ["getrandom"] }
-
-[features]
-default = []
-in-memory = []
-redis-db = ["redis"]
diff --git a/taskdb/src/lib.rs b/taskdb/src/lib.rs
index 73a3bd33d..2f1b27a68 100644
--- a/taskdb/src/lib.rs
+++ b/taskdb/src/lib.rs
@@ -1,55 +1,11 @@
-use std::io::{Error as IOError, ErrorKind as IOErrorKind};
-
-use chrono::{DateTime, Utc};
 use raiko_core::interfaces::AggregationOnlyRequest;
 use raiko_lib::{
     primitives::{ChainId, B256},
     proof_type::ProofType,
-    prover::{IdStore, IdWrite, ProofKey, ProverResult},
 };
 use serde::{Deserialize, Serialize};
-use tracing::debug;
 use utoipa::ToSchema;
 
-#[cfg(feature = "in-memory")]
-use crate::mem_db::InMemoryTaskManager;
-#[cfg(feature = "redis-db")]
-use crate::redis_db::RedisTaskManager;
-
-#[cfg(feature = "in-memory")]
-mod mem_db;
-#[cfg(feature = "redis-db")]
-mod redis_db;
-
-// Types
-// ----------------------------------------------------------------
-#[derive(Debug, thiserror::Error)]
-pub enum TaskManagerError {
-    #[error("IO Error {0}")]
-    IOError(IOErrorKind),
-    #[cfg(feature = "redis-db")]
-    #[error("Redis Error {0}")]
-    RedisError(#[from] crate::redis_db::RedisDbError),
-    #[error("No data for query")]
-    NoData,
-    #[error("Anyhow error: {0}")]
-    Anyhow(String),
-}
-
-pub type TaskManagerResult<T> = Result<T, TaskManagerError>;
-
-impl From<IOError> for TaskManagerError {
-    fn from(error: IOError) -> TaskManagerError {
-        TaskManagerError::IOError(error.kind())
-    }
-}
-
-impl From<anyhow::Error> for TaskManagerError {
-    fn from(value: anyhow::Error) -> Self {
-        TaskManagerError::Anyhow(value.to_string())
-    }
-}
-
 #[allow(non_camel_case_types)]
 #[rustfmt::skip]
 #[derive(PartialEq, Debug, Clone, Deserialize,
Serialize, ToSchema, Eq, PartialOrd, Ord)]
@@ -149,38 +105,6 @@ pub struct ProofTaskDescriptor {
     pub prover: String,
 }
 
-impl From<(ChainId, u64, B256, ProofType, String)> for ProofTaskDescriptor {
-    fn from(
-        (chain_id, block_id, blockhash, proof_system, prover): (
-            ChainId,
-            u64,
-            B256,
-            ProofType,
-            String,
-        ),
-    ) -> Self {
-        ProofTaskDescriptor {
-            chain_id,
-            block_id,
-            blockhash,
-            proof_system,
-            prover,
-        }
-    }
-}
-
-impl From<ProofTaskDescriptor> for (ChainId, B256) {
-    fn from(
-        ProofTaskDescriptor {
-            chain_id,
-            blockhash,
-            ..
-        }: ProofTaskDescriptor,
-    ) -> Self {
-        (chain_id, blockhash)
-    }
-}
-
 #[derive(Default, Clone, Serialize, Deserialize, Debug, PartialEq, Eq, Hash)]
 #[serde(default)]
 /// A request for proof aggregation of multiple proofs.
@@ -191,21 +115,6 @@ pub struct AggregationTaskDescriptor {
     pub proof_type: Option<String>,
 }
 
-impl From<&AggregationOnlyRequest> for AggregationTaskDescriptor {
-    fn from(request: &AggregationOnlyRequest) -> Self {
-        Self {
-            aggregation_ids: request.aggregation_ids.clone(),
-            proof_type: request.proof_type.clone().map(|p| p.to_string()),
-        }
-    }
-}
-
-/// Task status triplet (status, proof, timestamp).
-pub type TaskProvingStatus = (TaskStatus, Option<String>, DateTime<Utc>);
-
-#[derive(Debug, Default, Clone, Serialize, Deserialize)]
-pub struct TaskProvingStatusRecords(pub Vec<TaskProvingStatus>);
-
 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
 pub enum TaskDescriptor {
     SingleProof(ProofTaskDescriptor),
@@ -215,218 +124,3 @@ pub enum TaskDescriptor {
 
 pub type TaskReport = (TaskDescriptor, TaskStatus);
 
 pub type AggregationTaskReport = (AggregationOnlyRequest, TaskStatus);
-
-#[derive(Debug, Clone, Default)]
-pub struct TaskManagerOpts {
-    pub max_db_size: usize,
-    pub redis_url: String,
-    pub redis_ttl: u64,
-}
-
-#[async_trait::async_trait]
-pub trait TaskManager: IdStore + IdWrite + Send + Sync {
-    /// Create a new task manager.
-    fn new(opts: &TaskManagerOpts) -> Self;
-
-    /// Enqueue a new task to the tasks database.
-    async fn enqueue_task(
-        &mut self,
-        request: &ProofTaskDescriptor,
-    ) -> TaskManagerResult<TaskProvingStatusRecords>;
-
-    /// Update a specific task's progress.
-    async fn update_task_progress(
-        &mut self,
-        key: ProofTaskDescriptor,
-        status: TaskStatus,
-        proof: Option<&[u8]>,
-    ) -> TaskManagerResult<()>;
-
-    /// Returns the latest triplet (status, proof - if any, last update time).
-    async fn get_task_proving_status(
-        &mut self,
-        key: &ProofTaskDescriptor,
-    ) -> TaskManagerResult<TaskProvingStatusRecords>;
-
-    /// Returns the proof for the given task.
-    async fn get_task_proof(&mut self, key: &ProofTaskDescriptor) -> TaskManagerResult<Vec<u8>>;
-
-    /// Returns the total and detailed database size.
-    async fn get_db_size(&mut self) -> TaskManagerResult<(usize, Vec<(String, usize)>)>;
-
-    /// Prune old tasks.
-    async fn prune_db(&mut self) -> TaskManagerResult<()>;
-
-    /// List all tasks in the db.
-    async fn list_all_tasks(&mut self) -> TaskManagerResult<Vec<TaskReport>>;
-
-    /// List all stored ids.
-    async fn list_stored_ids(&mut self) -> TaskManagerResult<Vec<(ProofKey, String)>>;
-
-    /// Enqueue a new aggregation task to the tasks database.
-    async fn enqueue_aggregation_task(
-        &mut self,
-        request: &AggregationOnlyRequest,
-    ) -> TaskManagerResult<()>;
-
-    /// Update a specific aggregation task's progress.
-    async fn update_aggregation_task_progress(
-        &mut self,
-        request: &AggregationOnlyRequest,
-        status: TaskStatus,
-        proof: Option<&[u8]>,
-    ) -> TaskManagerResult<()>;
-
-    /// Returns the latest triplet (status, proof - if any, last update time).
-    async fn get_aggregation_task_proving_status(
-        &mut self,
-        request: &AggregationOnlyRequest,
-    ) -> TaskManagerResult<TaskProvingStatusRecords>;
-
-    /// Returns the proof for the given aggregation task.
-    async fn get_aggregation_task_proof(
-        &mut self,
-        request: &AggregationOnlyRequest,
-    ) -> TaskManagerResult<Vec<u8>>;
-
-    /// Prune old tasks.
-    async fn prune_aggregation_db(&mut self) -> TaskManagerResult<()>;
-
-    /// List all tasks in the db.
-    async fn list_all_aggregation_tasks(&mut self)
-        -> TaskManagerResult<Vec<AggregationTaskReport>>;
-}
-
-pub fn ensure(expression: bool, message: &str) -> TaskManagerResult<()> {
-    if !expression {
-        return Err(TaskManagerError::Anyhow(message.to_string()));
-    }
-    Ok(())
-}
-
-pub struct TaskManagerWrapper<T: TaskManager> {
-    manager: T,
-}
-
-#[async_trait::async_trait]
-impl<T: TaskManager> IdWrite for TaskManagerWrapper<T> {
-    async fn store_id(&mut self, key: ProofKey, id: String) -> ProverResult<()> {
-        self.manager.store_id(key, id).await
-    }
-
-    async fn remove_id(&mut self, key: ProofKey) -> ProverResult<()> {
-        self.manager.remove_id(key).await
-    }
-}
-
-#[async_trait::async_trait]
-impl<T: TaskManager> IdStore for TaskManagerWrapper<T> {
-    async fn read_id(&mut self, key: ProofKey) -> ProverResult<String> {
-        self.manager.read_id(key).await
-    }
-}
-
-#[async_trait::async_trait]
-impl<T: TaskManager> TaskManager for TaskManagerWrapper<T> {
-    fn new(opts: &TaskManagerOpts) -> Self {
-        let manager = T::new(opts);
-        Self { manager }
-    }
-
-    async fn enqueue_task(
-        &mut self,
-        request: &ProofTaskDescriptor,
-    ) -> TaskManagerResult<TaskProvingStatusRecords> {
-        self.manager.enqueue_task(request).await
-    }
-
-    async fn update_task_progress(
-        &mut self,
-        key: ProofTaskDescriptor,
-        status: TaskStatus,
-        proof: Option<&[u8]>,
-    ) -> TaskManagerResult<()> {
-        self.manager.update_task_progress(key, status, proof).await
-    }
-
-    async fn get_task_proving_status(
-        &mut self,
-        key: &ProofTaskDescriptor,
-    ) -> TaskManagerResult<TaskProvingStatusRecords> {
-        self.manager.get_task_proving_status(key).await
-    }
-
-    async fn get_task_proof(&mut self, key: &ProofTaskDescriptor) -> TaskManagerResult<Vec<u8>> {
-        self.manager.get_task_proof(key).await
-    }
-
-    async fn get_db_size(&mut self) -> TaskManagerResult<(usize, Vec<(String, usize)>)> {
-        self.manager.get_db_size().await
-    }
-
-    async fn prune_db(&mut self) -> TaskManagerResult<()> {
-        self.manager.prune_db().await
-    }
-
-    async fn list_all_tasks(&mut self) -> TaskManagerResult<Vec<TaskReport>> {
-        self.manager.list_all_tasks().await
-    }
-
-    async fn list_stored_ids(&mut self) -> TaskManagerResult<Vec<(ProofKey, String)>> {
-        self.manager.list_stored_ids().await
-    }
-
-    async fn enqueue_aggregation_task(
-        &mut self,
-        request: &AggregationOnlyRequest,
-    ) -> TaskManagerResult<()> {
-        self.manager.enqueue_aggregation_task(request).await
-    }
-
-    async fn update_aggregation_task_progress(
-        &mut self,
-        request: &AggregationOnlyRequest,
-        status: TaskStatus,
-        proof: Option<&[u8]>,
-    ) -> TaskManagerResult<()> {
-        self.manager
-            .update_aggregation_task_progress(request, status, proof)
-            .await
-    }
-
-    async fn get_aggregation_task_proving_status(
-        &mut self,
-        request: &AggregationOnlyRequest,
-    ) -> TaskManagerResult<TaskProvingStatusRecords> {
-        self.manager
-            .get_aggregation_task_proving_status(request)
-            .await
-    }
-
-    async fn get_aggregation_task_proof(
-        &mut self,
-        request: &AggregationOnlyRequest,
-    ) -> TaskManagerResult<Vec<u8>> {
-        self.manager.get_aggregation_task_proof(request).await
-    }
-
-    async fn prune_aggregation_db(&mut self) -> TaskManagerResult<()> {
-        self.manager.prune_aggregation_db().await
-    }
-
-    async fn list_all_aggregation_tasks(
-        &mut self,
-    ) -> TaskManagerResult<Vec<AggregationTaskReport>> {
-        self.manager.list_all_aggregation_tasks().await
-    }
-}
-
-#[cfg(feature = "in-memory")]
-pub type TaskManagerWrapperImpl = TaskManagerWrapper<InMemoryTaskManager>;
-#[cfg(feature = "redis-db")]
-pub type TaskManagerWrapperImpl = TaskManagerWrapper<RedisTaskManager>;
-
-pub fn get_task_manager(opts: &TaskManagerOpts) -> TaskManagerWrapperImpl {
-    debug!("get task manager with options: {:?}", opts);
-    TaskManagerWrapperImpl::new(opts)
-}
diff --git a/taskdb/src/mem_db.rs b/taskdb/src/mem_db.rs
deleted file mode 100644
index a16b6735c..000000000
--- a/taskdb/src/mem_db.rs
+++ /dev/null
@@ -1,424 +0,0 @@
-// Raiko
-// Copyright (c) 2024 Taiko Labs
-// Licensed and distributed under either of
-//   * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
-//   * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
-// at your option. This file may not be copied, modified, or distributed except according to those terms.
-
-// Imports
-// ----------------------------------------------------------------
-use std::{
-    collections::HashMap,
-    sync::{Arc, Once},
-};
-
-use chrono::Utc;
-use raiko_core::interfaces::AggregationOnlyRequest;
-use raiko_lib::prover::{IdStore, IdWrite, ProofKey, ProverError, ProverResult};
-use tokio::sync::Mutex;
-use tracing::{info, warn};
-
-use crate::{
-    ensure, AggregationTaskDescriptor, AggregationTaskReport, ProofTaskDescriptor, TaskDescriptor,
-    TaskManager, TaskManagerError, TaskManagerOpts, TaskManagerResult, TaskProvingStatusRecords,
-    TaskReport, TaskStatus,
-};
-
-#[derive(Debug)]
-pub struct InMemoryTaskManager {
-    db: Arc<Mutex<InMemoryTaskDb>>,
-}
-
-#[derive(Debug)]
-pub struct InMemoryTaskDb {
-    tasks_queue: HashMap<ProofTaskDescriptor, TaskProvingStatusRecords>,
-    aggregation_tasks_queue: HashMap<AggregationOnlyRequest, TaskProvingStatusRecords>,
-    store: HashMap<ProofKey, String>,
-}
-
-impl InMemoryTaskDb {
-    fn new() -> InMemoryTaskDb {
-        InMemoryTaskDb {
-            tasks_queue: HashMap::new(),
-            aggregation_tasks_queue: HashMap::new(),
-            store: HashMap::new(),
-        }
-    }
-
-    fn enqueue_task(&mut self, key: &ProofTaskDescriptor) -> TaskManagerResult<()> {
-        let task_status = (TaskStatus::Registered, None, Utc::now());
-
-        match self.tasks_queue.get(key) {
-            Some(task_proving_records) => {
-                let previous_status = &task_proving_records.0.last().unwrap().0;
-                warn!("Task already exists: {key:?} with previous status {previous_status:?}");
-                if previous_status != &TaskStatus::Success {
-                    self.update_task_progress(key.clone(), TaskStatus::Registered, None)?;
-                }
-            } // do nothing
-            None => {
-                info!("Enqueue new task: {key:?}");
-                self.tasks_queue
-                    .insert(key.clone(), TaskProvingStatusRecords(vec![task_status]));
-            }
-        }
-
-        Ok(())
-    }
-
-    fn update_task_progress(
-        &mut self,
-        key: ProofTaskDescriptor,
-        status: TaskStatus,
-        proof: Option<&[u8]>,
-    ) -> TaskManagerResult<()> {
-        ensure(self.tasks_queue.contains_key(&key), "no task found")?;
-
-        self.tasks_queue.entry(key).and_modify(|entry| {
-            if let Some(latest) = entry.0.last() {
-                if latest.0 != status {
-                    entry.0.push((status, proof.map(hex::encode), Utc::now()));
-                }
-            }
-        });
-
-        Ok(())
-    }
-
-    fn get_task_proving_status(
-        &mut self,
-        key: &ProofTaskDescriptor,
-    ) -> TaskManagerResult<TaskProvingStatusRecords> {
-        Ok(self.tasks_queue.get(key).cloned().unwrap_or_default())
-    }
-
-    fn get_task_proof(&mut self, key: &ProofTaskDescriptor) -> TaskManagerResult<Vec<u8>> {
-        ensure(self.tasks_queue.contains_key(key), "no task found")?;
-
-        let proving_status_records = self
-            .tasks_queue
-            .get(key)
-            .ok_or_else(|| TaskManagerError::Anyhow("no task in db".to_owned()))?;
-
-        let (_, proof, ..)
= proving_status_records
-            .0
-            .iter()
-            .filter(|(status, ..)| (status == &TaskStatus::Success))
-            .last()
-            .ok_or_else(|| TaskManagerError::Anyhow("no successful task in db".to_owned()))?;
-
-        let Some(proof) = proof else {
-            return Ok(vec![]);
-        };
-
-        hex::decode(proof)
-            .map_err(|_| TaskManagerError::Anyhow("couldn't decode from hex".to_owned()))
-    }
-
-    fn size(&mut self) -> TaskManagerResult<(usize, Vec<(String, usize)>)> {
-        Ok((self.tasks_queue.len(), vec![]))
-    }
-
-    fn prune(&mut self) -> TaskManagerResult<()> {
-        self.tasks_queue.clear();
-        Ok(())
-    }
-
-    fn list_all_tasks(&mut self) -> TaskManagerResult<Vec<TaskReport>> {
-        let single_proofs = self.tasks_queue.iter().filter_map(|(desc, statuses)| {
-            statuses
-                .0
-                .last()
-                .map(|s| (TaskDescriptor::SingleProof(desc.clone()), s.0.clone()))
-        });
-
-        let aggregations = self
-            .aggregation_tasks_queue
-            .iter()
-            .filter_map(|(desc, statuses)| {
-                statuses.0.last().map(|s| {
-                    (
-                        TaskDescriptor::Aggregation(AggregationTaskDescriptor::from(desc)),
-                        s.0.clone(),
-                    )
-                })
-            });
-
-        Ok(single_proofs.chain(aggregations).collect())
-    }
-
-    fn list_stored_ids(&mut self) -> TaskManagerResult<Vec<(ProofKey, String)>> {
-        Ok(self.store.iter().map(|(k, v)| (*k, v.clone())).collect())
-    }
-
-    fn store_id(&mut self, key: ProofKey, id: String) -> TaskManagerResult<()> {
-        self.store.insert(key, id);
-        Ok(())
-    }
-
-    fn remove_id(&mut self, key: ProofKey) -> TaskManagerResult<()> {
-        self.store.remove(&key);
-        Ok(())
-    }
-
-    fn read_id(&mut self, key: ProofKey) -> TaskManagerResult<String> {
-        self.store
-            .get(&key)
-            .cloned()
-            .ok_or(TaskManagerError::NoData)
-    }
-
-    fn enqueue_aggregation_task(
-        &mut self,
-        request: &AggregationOnlyRequest,
-    ) -> TaskManagerResult<()> {
-        let task_status = (TaskStatus::Registered, None, Utc::now());
-
-        match self.aggregation_tasks_queue.get(request) {
-            Some(task_proving_records) => {
-                let previous_status = &task_proving_records.0.last().unwrap().0;
-                warn!("Task already exists: {request} with previous status {previous_status:?}");
-                if previous_status != &TaskStatus::Success {
-                    self.update_aggregation_task_progress(request, TaskStatus::Registered, None)?;
-                }
-            } // do nothing
-            None => {
-                info!("Enqueue new task: {request}");
-                self.aggregation_tasks_queue
-                    .insert(request.clone(), TaskProvingStatusRecords(vec![task_status]));
-            }
-        }
-        Ok(())
-    }
-
-    fn get_aggregation_task_proving_status(
-        &mut self,
-        request: &AggregationOnlyRequest,
-    ) -> TaskManagerResult<TaskProvingStatusRecords> {
-        Ok(self
-            .aggregation_tasks_queue
-            .get(request)
-            .cloned()
-            .unwrap_or_default())
-    }
-
-    fn update_aggregation_task_progress(
-        &mut self,
-        request: &AggregationOnlyRequest,
-        status: TaskStatus,
-        proof: Option<&[u8]>,
-    ) -> TaskManagerResult<()> {
-        ensure(
-            self.aggregation_tasks_queue.contains_key(request),
-            "no task found",
-        )?;
-
-        self.aggregation_tasks_queue
-            .entry(request.clone())
-            .and_modify(|entry| {
-                if let Some(latest) = entry.0.last() {
-                    if latest.0 != status {
-                        entry.0.push((status, proof.map(hex::encode), Utc::now()));
-                    }
-                }
-            });
-
-        Ok(())
-    }
-
-    fn get_aggregation_task_proof(
-        &mut self,
-        request: &AggregationOnlyRequest,
-    ) -> TaskManagerResult<Vec<u8>> {
-        ensure(
-            self.aggregation_tasks_queue.contains_key(request),
-            "no task found",
-        )?;
-
-        let proving_status_records = self
-            .aggregation_tasks_queue
-            .get(request)
-            .ok_or_else(|| TaskManagerError::Anyhow("no task in db".to_owned()))?;
-
-        let (_, proof, ..)
= proving_status_records
-            .0
-            .iter()
-            .filter(|(status, ..)| (status == &TaskStatus::Success))
-            .last()
-            .ok_or_else(|| TaskManagerError::Anyhow("no successful task in db".to_owned()))?;
-
-        let Some(proof) = proof else {
-            return Ok(vec![]);
-        };
-
-        hex::decode(proof)
-            .map_err(|_| TaskManagerError::Anyhow("couldn't decode from hex".to_owned()))
-    }
-
-    fn prune_aggregation(&mut self) -> TaskManagerResult<()> {
-        self.aggregation_tasks_queue.clear();
-        Ok(())
-    }
-
-    fn list_all_aggregation_tasks(&mut self) -> TaskManagerResult<Vec<AggregationTaskReport>> {
-        Ok(self
-            .aggregation_tasks_queue
-            .iter()
-            .flat_map(|(request, statuses)| {
-                statuses
-                    .0
-                    .last()
-                    .map(|status| (request.clone(), status.0.clone()))
-            })
-            .collect())
-    }
-}
-
-#[async_trait::async_trait]
-impl IdWrite for InMemoryTaskManager {
-    async fn store_id(&mut self, key: ProofKey, id: String) -> ProverResult<()> {
-        let mut db = self.db.lock().await;
-        db.store_id(key, id)
-            .map_err(|e| ProverError::StoreError(e.to_string()))
-    }
-
-    async fn remove_id(&mut self, key: ProofKey) -> ProverResult<()> {
-        let mut db = self.db.lock().await;
-        db.remove_id(key)
-            .map_err(|e| ProverError::StoreError(e.to_string()))
-    }
-}
-
-#[async_trait::async_trait]
-impl IdStore for InMemoryTaskManager {
-    async fn read_id(&mut self, key: ProofKey) -> ProverResult<String> {
-        let mut db = self.db.lock().await;
-        db.read_id(key)
-            .map_err(|e| ProverError::StoreError(e.to_string()))
-    }
-}
-
-#[async_trait::async_trait]
-impl TaskManager for InMemoryTaskManager {
-    fn new(_opts: &TaskManagerOpts) -> Self {
-        static INIT: Once = Once::new();
-        static mut SHARED_TASK_MANAGER: Option<Arc<Mutex<InMemoryTaskDb>>> = None;
-
-        INIT.call_once(|| {
-            let task_manager: Arc<Mutex<InMemoryTaskDb>> =
-                Arc::new(Mutex::new(InMemoryTaskDb::new()));
-            unsafe {
-                SHARED_TASK_MANAGER = Some(Arc::clone(&task_manager));
-            }
-        });
-
-        InMemoryTaskManager {
-            db: unsafe { SHARED_TASK_MANAGER.clone().unwrap() },
-        }
-    }
-
-    async fn enqueue_task(
-        &mut self,
-        params: &ProofTaskDescriptor,
-    ) -> TaskManagerResult<TaskProvingStatusRecords> {
-        let mut db = self.db.lock().await;
-        let status = db.get_task_proving_status(params)?;
-        if !status.0.is_empty() {
-            return Ok(status);
-        }
-
-        db.enqueue_task(params)?;
-        db.get_task_proving_status(params)
-    }
-
-    async fn update_task_progress(
-        &mut self,
-        key: ProofTaskDescriptor,
-        status: TaskStatus,
-        proof: Option<&[u8]>,
-    ) -> TaskManagerResult<()> {
-        let mut db = self.db.lock().await;
-        db.update_task_progress(key, status, proof)
-    }
-
-    /// Returns the latest triplet (submitter or fulfiller, status, last update time)
-    async fn get_task_proving_status(
-        &mut self,
-        key: &ProofTaskDescriptor,
-    ) -> TaskManagerResult<TaskProvingStatusRecords> {
-        let mut db = self.db.lock().await;
-        db.get_task_proving_status(key)
-    }
-
-    async fn get_task_proof(&mut self, key: &ProofTaskDescriptor) -> TaskManagerResult<Vec<u8>> {
-        let mut db = self.db.lock().await;
-        db.get_task_proof(key)
-    }
-
-    /// Returns the total and detailed database size
-    async fn get_db_size(&mut self) -> TaskManagerResult<(usize, Vec<(String, usize)>)> {
-        let mut db = self.db.lock().await;
-        db.size()
-    }
-
-    async fn prune_db(&mut self) -> TaskManagerResult<()> {
-        let mut db = self.db.lock().await;
-        db.prune()
-    }
-
-    async fn list_all_tasks(&mut self) -> TaskManagerResult<Vec<TaskReport>> {
-        let mut db = self.db.lock().await;
-        db.list_all_tasks()
-    }
-
-    async fn list_stored_ids(&mut self) -> TaskManagerResult<Vec<(ProofKey, String)>> {
-        let mut db = self.db.lock().await;
-        db.list_stored_ids()
-    }
-
-    async fn enqueue_aggregation_task(
-        &mut self,
-        request: &AggregationOnlyRequest,
-    ) ->
TaskManagerResult<()> {
-        let mut db = self.db.lock().await;
-        db.enqueue_aggregation_task(request)
-    }
-
-    async fn get_aggregation_task_proving_status(
-        &mut self,
-        request: &AggregationOnlyRequest,
-    ) -> TaskManagerResult<TaskProvingStatusRecords> {
-        let mut db = self.db.lock().await;
-        db.get_aggregation_task_proving_status(request)
-    }
-
-    async fn update_aggregation_task_progress(
-        &mut self,
-        request: &AggregationOnlyRequest,
-        status: TaskStatus,
-        proof: Option<&[u8]>,
-    ) -> TaskManagerResult<()> {
-        let mut db = self.db.lock().await;
-        db.update_aggregation_task_progress(request, status, proof)
-    }
-
-    async fn get_aggregation_task_proof(
-        &mut self,
-        request: &AggregationOnlyRequest,
-    ) -> TaskManagerResult<Vec<u8>> {
-        let mut db = self.db.lock().await;
-        db.get_aggregation_task_proof(request)
-    }
-
-    async fn prune_aggregation_db(&mut self) -> TaskManagerResult<()> {
-        let mut db = self.db.lock().await;
-        db.prune_aggregation()
-    }
-
-    async fn list_all_aggregation_tasks(
-        &mut self,
-    ) -> TaskManagerResult<Vec<AggregationTaskReport>> {
-        let mut db = self.db.lock().await;
-        db.list_all_aggregation_tasks()
-    }
-}
diff --git a/taskdb/src/redis_db.rs b/taskdb/src/redis_db.rs
deleted file mode 100644
index 91f9eec75..000000000
--- a/taskdb/src/redis_db.rs
+++ /dev/null
@@ -1,821 +0,0 @@
-#![cfg(feature = "redis-db")]
-// Raiko
-// Copyright (c) 2024 Taiko Labs
-// Licensed and distributed under either of
-//   * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
-//   * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
-// at your option. This file may not be copied, modified, or distributed except according to those terms.
-
-// Imports
-// ----------------------------------------------------------------
-use backoff::ExponentialBackoff;
-use chrono::Utc;
-use raiko_core::interfaces::AggregationOnlyRequest;
-use raiko_lib::prover::{IdStore, IdWrite, ProofKey, ProverError, ProverResult};
-use redis::{
-    Client, Commands, ErrorKind, FromRedisValue, RedisError, RedisResult, RedisWrite, ToRedisArgs,
-    Value,
-};
-use std::sync::{Arc, Once};
-use std::time::Duration;
-use thiserror::Error;
-use tokio::sync::Mutex;
-use tracing::{error, info, warn};
-
-use crate::{
-    AggregationTaskDescriptor, AggregationTaskReport, ProofTaskDescriptor, TaskDescriptor,
-    TaskManager, TaskManagerError, TaskManagerOpts, TaskManagerResult, TaskProvingStatus,
-    TaskProvingStatusRecords, TaskReport, TaskStatus,
-};
-
-pub struct RedisTaskDb {
-    client: Client,
-    config: RedisConfig,
-}
-
-pub struct RedisTaskManager {
-    arc_task_db: Arc<Mutex<RedisTaskDb>>,
-}
-
-type RedisDbResult<T> = Result<T, RedisDbError>;
-
-#[derive(Error, Debug)]
-pub enum RedisDbError {
-    #[error("Redis DB error: {0}")]
-    RedisDb(#[from] RedisError),
-    #[error("Redis Task Manager error: {0}")]
-    TaskManager(String),
-    #[error("Serialization error: {0}")]
-    Serialization(#[from] serde_json::Error),
-    #[error("Redis key non-exist: {0}")]
-    KeyNotFound(String),
-}
-
-impl ToRedisArgs for ProofTaskDescriptor {
-    fn write_redis_args<W>(&self, out: &mut W)
-    where
-        W: ?Sized + RedisWrite,
-    {
-        let serialized = serde_json::to_string(self).expect("Failed to serialize TaskDescriptor");
-        out.write_arg(serialized.as_bytes());
-    }
-}
-
-impl FromRedisValue for ProofTaskDescriptor {
-    fn from_redis_value(v: &Value) -> RedisResult<Self> {
-        let serialized = String::from_redis_value(v)?;
-        serde_json::from_str(&serialized).map_err(|_| {
-            RedisError::from((
-                ErrorKind::TypeError,
-                "ProofTaskDescriptor type conversion fail",
-            ))
-        })
-    }
-}
-
-impl ToRedisArgs for AggregationTaskDescriptor {
-    fn write_redis_args<W>(&self, out: &mut W)
-    where
-        W: ?Sized + RedisWrite,
-    {
-        let serialized =
-            serde_json::to_string(self).expect("Failed to serialize AggregationTaskDescriptor");
-        out.write_arg(serialized.as_bytes());
-    }
-}
-
-impl FromRedisValue for AggregationTaskDescriptor {
-    fn from_redis_value(v: &Value) -> RedisResult<Self> {
-        let serialized = String::from_redis_value(v)?;
-        serde_json::from_str(&serialized).map_err(|_| {
-            RedisError::from((
-                ErrorKind::TypeError,
-                "AggregationTaskDescriptor type conversion fail",
-            ))
-        })
-    }
-}
-
-impl ToRedisArgs for TaskProvingStatusRecords {
-    fn write_redis_args<W>(&self, out: &mut W)
-    where
-        W: ?Sized + RedisWrite,
-    {
-        let serialized =
-            serde_json::to_string(self).expect("Failed to serialize TaskProvingStatusRecords");
-        out.write_arg(serialized.as_bytes());
-    }
-}
-
-impl FromRedisValue for TaskProvingStatusRecords {
-    fn from_redis_value(v: &Value) -> RedisResult<Self> {
-        let serialized = String::from_redis_value(v)?;
-        serde_json::from_str(&serialized).map_err(|_| {
-            RedisError::from((
-                ErrorKind::TypeError,
-                "TaskProvingStatusRecords type conversion fail",
-            ))
-        })
-    }
-}
-
-struct TaskIdDescriptor(ProofKey);
-
-impl ToRedisArgs for TaskIdDescriptor {
-    fn write_redis_args<W>(&self, out: &mut W)
-    where
-        W: ?Sized + RedisWrite,
-    {
-        let serialized =
-            serde_json::to_string(&self.0).expect("Failed to serialize TaskIDDescriptor");
-        out.write_arg(serialized.as_bytes());
-    }
-}
-
-impl FromRedisValue for TaskIdDescriptor {
-    fn from_redis_value(v: &Value) -> RedisResult<Self> {
-        let serialized = String::from_redis_value(v)?;
-        let proof_key = serde_json::from_str(&serialized).map_err(|_| {
-            RedisError::from((
-                ErrorKind::TypeError,
-                "TaskIdDescriptor type conversion fail",
-            ))
-        })?;
-        Ok(TaskIdDescriptor(proof_key))
-    }
-}
-
-#[derive(Debug, Clone, Default)]
-pub struct RedisConfig {
-    url: String,
-    ttl: u64,
-}
-
-impl RedisTaskDb {
-    fn new(config: RedisConfig) -> RedisDbResult<Self> {
-        let url = config.url.clone();
-        let client = Client::open(url).map_err(RedisDbError::RedisDb)?;
-        Ok(RedisTaskDb { client, config })
-    }
-
-    fn get_conn(&mut self) -> Result<redis::Connection, RedisError> {
-        let backoff = ExponentialBackoff {
-            initial_interval: Duration::from_secs(10),
-            max_interval: Duration::from_secs(60),
-            max_elapsed_time: Some(Duration::from_secs(300)),
-            ..Default::default()
-        };
-
-        backoff::retry(backoff, || match self.client.get_connection() {
-            Ok(conn) => Ok(conn),
-            Err(e) => {
-                error!("Failed to connect to redis: {e:?}, retrying...");
-                self.client = redis::Client::open(self.config.url.clone())?;
-                Err(backoff::Error::Transient {
-                    err: e,
-                    retry_after: None,
-                })
-            }
-        })
-        .map_err(|e| match e {
-            backoff::Error::Transient {
-                err,
-                retry_after: _,
-            }
-            | backoff::Error::Permanent(err) => err,
-        })
-    }
-
-    fn insert_proof_task(
-        &mut self,
-        key: &ProofTaskDescriptor,
-        value: &TaskProvingStatusRecords,
-    ) -> RedisDbResult<()> {
-        self.insert_redis(key, value)
-    }
-
-    fn insert_aggregation_task(
-        &mut self,
-        key: &AggregationTaskDescriptor,
-        value: &TaskProvingStatusRecords,
-    ) -> RedisDbResult<()> {
-        self.insert_redis(key, value)
-    }
-
-    fn insert_redis<K, V>(&mut self, key: &K, value: &V) -> RedisDbResult<()>
-    where
-        K: ToRedisArgs,
-        V: ToRedisArgs,
-    {
-        self.get_conn()?
-            .set_ex(key, value, self.config.ttl)
-            .map_err(RedisDbError::RedisDb)?;
-        Ok(())
-    }
-
-    fn query_proof_task(
-        &mut self,
-        key: &ProofTaskDescriptor,
-    ) -> RedisDbResult<Option<TaskProvingStatusRecords>> {
-        match self.query_redis(&key) {
-            Ok(Some(v)) => {
-                if let Some(records) = serde_json::from_str(&v)? {
-                    Ok(Some(records))
-                } else {
-                    error!("Failed to deserialize TaskProvingStatusRecords");
-                    Err(RedisDbError::TaskManager(
-                        format!("Failed to deserialize TaskProvingStatusRecords").to_owned(),
-                    ))
-                }
-            }
-            Ok(None) => Ok(None),
-            Err(e) => Err(e),
-        }
-    }
-
-    fn query_proof_task_latest_status(
-        &mut self,
-        key: &ProofTaskDescriptor,
-    ) -> RedisDbResult<Option<TaskProvingStatus>> {
-        self.query_proof_task(key)
-            .map(|v| v.map(|records| records.0.last().unwrap().clone()))
-    }
-
-    fn query_aggregation_task(
-        &mut self,
-        key: &AggregationTaskDescriptor,
-    ) -> RedisDbResult<Option<TaskProvingStatusRecords>> {
-        match self.query_redis(&key) {
-            Ok(Some(v)) => Ok(Some(serde_json::from_str(&v)?)),
-            Ok(None) => Ok(None),
-            Err(e) => Err(e),
-        }
-    }
-
-    fn query_aggregation_task_latest_status(
-        &mut self,
-        key: &AggregationTaskDescriptor,
-    ) -> RedisDbResult<Option<TaskProvingStatus>> {
-        self.query_aggregation_task(key)
-            .map(|v| v.map(|records| records.0.last().unwrap().clone()))
-    }
-
-    fn query_redis(&mut self, key: &impl ToRedisArgs) -> RedisDbResult<Option<String>> {
-        match self.get_conn()?.get(key) {
-            Ok(value) => Ok(Some(value)),
-            Err(e) if e.kind() == redis::ErrorKind::TypeError => Ok(None),
-            Err(e) => Err(RedisDbError::RedisDb(e)),
-        }
-    }
-
-    fn delete_redis(&mut self, key: &impl ToRedisArgs) -> RedisDbResult<()> {
-        let result: i32 = self.get_conn()?.del(key).map_err(RedisDbError::RedisDb)?;
-        if result != 1 {
-            return Err(RedisDbError::TaskManager("redis del".to_owned()));
-        }
-        Ok(())
-    }
-
-    fn update_proof_task_status(
-        &mut self,
-        key: &ProofTaskDescriptor,
-        new_status: TaskProvingStatus,
-    ) -> RedisDbResult<()> {
-        let old_value = self.query_proof_task(key).unwrap_or_default();
-        let mut records = match old_value {
-            Some(v) => v,
-            None => {
-                warn!("Update a unknown task: {key:?} to {new_status:?}");
-                TaskProvingStatusRecords(vec![])
-            }
-        };
-
-        records.0.push(new_status);
-        let k = serde_json::to_string(&key)?;
-        let v = serde_json::to_string(&records)?;
-
-        self.update_status_redis(&k, &v)
-    }
-
-    fn update_aggregation_status(
-        &mut self,
-        key: &AggregationTaskDescriptor,
-        new_status: TaskProvingStatus,
-    ) -> RedisDbResult<()> {
-        let old_value = self.query_aggregation_task(key.into()).unwrap_or_default();
-        let mut records = match old_value {
-            Some(v) => v,
-            None => {
-                warn!("Update a unknown task: {key:?} to {new_status:?}");
-                TaskProvingStatusRecords(vec![])
-            }
-        };
-
-        records.0.push(new_status);
-        let k = serde_json::to_string(&key)?;
-        let v = serde_json::to_string(&records)?;
-
-        self.update_status_redis(&k, &v)
-    }
-
-    fn update_status_redis(&mut self, k: &String, v: &String) -> RedisDbResult<()> {
-        self.get_conn()?.set_ex(k, v, self.config.ttl)?;
-        Ok(())
-    }
-}
-
-impl RedisTaskDb {
-    fn enqueue_task(&mut self, key: &ProofTaskDescriptor) -> RedisDbResult<TaskProvingStatus> {
-        let task_status = (TaskStatus::Registered, None, Utc::now());
-
-        match self.query_proof_task(key) {
-            Ok(Some(task_proving_records)) => {
-                warn!(
-                    "Task status exists: {:?}, register again",
-                    task_proving_records.0.last()
-                );
-                self.insert_proof_task(key, &TaskProvingStatusRecords(vec![task_status.clone()]))?;
-                Ok(task_status)
-            } // do nothing
-            Ok(None) => {
-                info!("Enqueue new task: {key:?}");
-                self.insert_proof_task(key, &TaskProvingStatusRecords(vec![task_status.clone()]))?;
-                Ok(task_status)
-            }
-            Err(e) => {
-                error!("Enqueue task failed: {e:?}");
-                Err(e)
-            }
-        }
-    }
-
-    fn update_task_progress(
-        &mut self,
-        key: ProofTaskDescriptor,
-        status: TaskStatus,
-        proof: Option<&[u8]>,
-    ) -> RedisDbResult<()> {
-        match self.query_proof_task(&key) {
-            Ok(Some(records)) => {
-                if let Some(latest) = records.0.last() {
-                    if latest.0 != status {
-                        let new_statue = (status, proof.map(hex::encode), Utc::now());
-                        self.update_proof_task_status(&key, new_statue)?;
-                    }
-                } else {
-                    return Err(RedisDbError::TaskManager(
-                        format!("task {key:?} not found").to_owned(),
-                    ));
-                }
-                Ok(())
-            }
-            Ok(None) => Err(RedisDbError::TaskManager(
-                format!("task {key:?} not found").to_owned(),
-            )),
-            Err(e) => Err(RedisDbError::TaskManager(
-                format!("query {key:?} error: {e:?}").to_owned(),
-            )),
-        }
-    }
-
-    fn get_task_proving_status(
-        &mut self,
-        key: &ProofTaskDescriptor,
-    ) -> RedisDbResult<TaskProvingStatusRecords> {
-        match self.query_proof_task(key) {
-            Ok(Some(records)) => Ok(records),
-            Ok(None) => Err(RedisDbError::KeyNotFound(
-                format!("task {key:?} not found").to_owned(),
-            )),
-            Err(e) => Err(RedisDbError::TaskManager(
-                format!("query {key:?} error: {e:?}").to_owned(),
-            )),
-        }
-    }
-
-    fn get_task_proof(&mut self, key: &ProofTaskDescriptor) -> RedisDbResult<Vec<u8>> {
-        let proving_status_records = self
-            .query_proof_task(key)
-            .map_err(|e| RedisDbError::TaskManager(format!("query error: {e:?}").to_owned()))?
-            .unwrap_or_default();
-
-        let (_, proof, ..) = proving_status_records
-            .0
-            .iter()
-            .filter(|(status, ..)| (status == &TaskStatus::Success))
-            .last()
-            .ok_or_else(|| {
-                RedisDbError::TaskManager(format!("task {key:?} not success.").to_owned())
-            })?;
-
-        if let Some(proof_str) = proof {
-            hex::decode(proof_str).map_err(|e| {
-                RedisDbError::TaskManager(
-                    format!("task {key:?} hex decode failed for {e:?}").to_owned(),
-                )
-            })
-        } else {
-            Ok(vec![])
-        }
-    }
-
-    fn prune(&mut self) -> RedisDbResult<()> {
-        let keys: Vec<Value> = self.get_conn()?.keys("*").map_err(RedisDbError::RedisDb)?;
-        for key in keys.iter() {
-            match (
-                ProofTaskDescriptor::from_redis_value(key),
-                AggregationTaskDescriptor::from_redis_value(key),
-            ) {
-                (Ok(desc), _) => {
-                    self.delete_redis(&desc)?;
-                }
-                (_, Ok(desc)) => {
-                    self.delete_redis(&desc)?;
-                }
-                _ => (),
-            }
-        }
-
-        self.prune_stored_ids()?;
-        Ok(())
-    }
-
-    fn list_all_tasks(&mut self) -> RedisDbResult<Vec<TaskReport>> {
-        let mut kvs = Vec::new();
-        let keys: Vec<Value> = self.get_conn()?.keys("*").map_err(RedisDbError::RedisDb)?;
-        for key in keys.iter() {
-            match (
-                ProofTaskDescriptor::from_redis_value(key),
-                AggregationTaskDescriptor::from_redis_value(key),
-            ) {
-                (Ok(desc), _) => {
-                    let status = self.query_proof_task_latest_status(&desc)?;
-                    status.map(|s| kvs.push((TaskDescriptor::SingleProof(desc), s.0)));
-                }
-                (_, Ok(desc)) => {
-                    let status = self.query_aggregation_task_latest_status(&desc)?;
-                    status.map(|s| kvs.push((TaskDescriptor::Aggregation(desc), s.0)));
-                }
-                _ => (),
-            }
-        }
-
-        Ok(kvs)
-    }
-
-    fn enqueue_aggregation_task(&mut self, request: &AggregationOnlyRequest) -> RedisDbResult<()> {
-        let task_status = (TaskStatus::Registered, None, Utc::now());
-        let agg_task_descriptor = request.into();
-        match self.query_aggregation_task(&agg_task_descriptor)? {
-            Some(task_proving_records) => {
-                info!(
-                    "Task already exists: {:?}",
-                    task_proving_records.0.last().unwrap().0
-                );
-            } // do nothing
-            None => {
-                info!("Enqueue new aggregation task: {request}");
-                self.insert_aggregation_task(
-                    &agg_task_descriptor,
-                    &TaskProvingStatusRecords(vec![task_status]),
-                )?;
-            }
-        }
-        Ok(())
-    }
-
-    fn get_aggregation_task_proving_status(
-        &mut self,
-        request: &AggregationOnlyRequest,
-    ) -> RedisDbResult<TaskProvingStatusRecords> {
-        let agg_task_descriptor = request.into();
-        match self.query_aggregation_task(&agg_task_descriptor)? {
-            Some(records) => Ok(records),
-            None => Err(RedisDbError::KeyNotFound(
-                format!("task {agg_task_descriptor:?} not found").to_owned(),
-            )),
-        }
-    }
-
-    fn update_aggregation_task_progress(
-        &mut self,
-        request: &AggregationOnlyRequest,
-        status: TaskStatus,
-        proof: Option<&[u8]>,
-    ) -> RedisDbResult<()> {
-        let agg_task_descriptor = request.into();
-        match self.query_aggregation_task(&agg_task_descriptor)? {
-            Some(records) => {
-                if let Some(latest) = records.0.last() {
-                    if latest.0 != status {
-                        let new_record = (status, proof.map(hex::encode), Utc::now());
-                        self.update_aggregation_status(&agg_task_descriptor, new_record)?;
-                    }
-                } else {
-                    return Err(RedisDbError::TaskManager(
-                        format!("task {agg_task_descriptor:?} not found").to_owned(),
-                    ));
-                }
-                Ok(())
-            }
-            None => Err(RedisDbError::TaskManager(
-                format!("task {agg_task_descriptor:?} not found").to_owned(),
-            )),
-        }
-    }
-
-    fn get_aggregation_task_proof(
-        &mut self,
-        request: &AggregationOnlyRequest,
-    ) -> RedisDbResult<Vec<u8>> {
-        let agg_task_descriptor = request.into();
-        let proving_status_records = self
-            .query_aggregation_task(&agg_task_descriptor)?
-            .unwrap_or_default();
-
-        let (_, proof, ..) = proving_status_records
-            .0
-            .iter()
-            .filter(|(status, ..)| (status == &TaskStatus::Success))
-            .last()
-            .ok_or_else(|| {
-                RedisDbError::TaskManager(
-                    format!("task {agg_task_descriptor:?} not found").to_owned(),
-                )
-            })?;
-
-        if let Some(proof) = proof {
-            hex::decode(proof).map_err(|e| {
-                RedisDbError::TaskManager(
-                    format!("task {agg_task_descriptor:?} hex decode failed for {e:?}").to_owned(),
-                )
-            })
-        } else {
-            Ok(vec![])
-        }
-    }
-
-    fn get_db_size(&self) -> TaskManagerResult<(usize, Vec<(String, usize)>)> {
-        // todo
-        Ok((0, vec![]))
-    }
-
-    fn prune_aggregation(&mut self) -> RedisDbResult<()> {
-        let keys: Vec<Value> = self.get_conn()?.keys("*").map_err(RedisDbError::RedisDb)?;
-        for key in keys.iter() {
-            match AggregationTaskDescriptor::from_redis_value(key) {
-                Ok(desc) => {
-                    self.delete_redis(&desc)?;
-                }
-                _ => (),
-            }
-        }
-        Ok(())
-    }
-
-    fn list_all_aggregation_tasks(&mut self) -> RedisDbResult<Vec<AggregationTaskReport>> {
-        let mut kvs: Vec<AggregationTaskReport> = Vec::new();
-        let keys: Vec<Value> = self.get_conn()?.keys("*").map_err(RedisDbError::RedisDb)?;
-        for key in keys.iter() {
-            match AggregationTaskDescriptor::from_redis_value(key) {
-                Ok(desc) => {
-                    let status = self.query_aggregation_task_latest_status(&desc)?;
-                    status.map(|s| {
-                        kvs.push((
-                            AggregationOnlyRequest {
-                                aggregation_ids: desc.aggregation_ids,
-                                proof_type: desc.proof_type,
-                                ..Default::default()
-                            },
-                            s.0,
-                        ))
-                    });
-                }
-                _ => (),
-            }
-        }
-        Ok(kvs)
-    }
-}
-
-impl RedisTaskDb {
-    fn store_id(&mut self, key: ProofKey, id: String) -> RedisDbResult<()> {
-        self.insert_redis(&TaskIdDescriptor(key), &id)
-    }
-
-    fn remove_id(&mut self, key: ProofKey) -> RedisDbResult<()> {
-        self.delete_redis(&TaskIdDescriptor(key))
-    }
-
-    fn read_id(&mut self, key: ProofKey) -> RedisDbResult<String> {
-        match self.query_redis(&TaskIdDescriptor(key)) {
-            Ok(Some(v)) => Ok(v),
-            Ok(None) => Err(RedisDbError::TaskManager(
-                format!("id {key:?} not found").to_owned(),
-            )),
-            Err(e) => Err(RedisDbError::TaskManager(
-                format!("id {key:?} query error: {e:?}").to_owned(),
-            )),
-        }
-    }
-
-    fn list_stored_ids(&mut self) -> RedisDbResult<Vec<(ProofKey, String)>> {
-        let mut kvs = Vec::new();
-        let keys: Vec<Value> = self.get_conn()?.keys("*").map_err(RedisDbError::RedisDb)?;
-        for key in keys.iter() {
-            match TaskIdDescriptor::from_redis_value(key) {
-                Ok(desc) => {
-                    let status = self.query_redis(&desc)?;
-                    status.map(|s| kvs.push((desc.0, s)));
-                }
-                _ => (),
-            }
-        }
-        Ok(kvs)
-    }
-
-    fn prune_stored_ids(&mut self) -> RedisDbResult<()> {
-        let keys: Vec<Value> = self.get_conn()?.keys("*").map_err(RedisDbError::RedisDb)?;
-        for key in keys.iter() {
-            match TaskIdDescriptor::from_redis_value(key) {
-                Ok(desc) => {
-                    self.delete_redis(&desc)?;
-                }
-                _ => (),
-            }
-        }
-        Ok(())
-    }
-}
-
-#[async_trait::async_trait]
-impl IdStore for RedisTaskManager {
-    async fn read_id(&mut self, key: ProofKey) -> ProverResult<String> {
-        let mut db = self.arc_task_db.lock().await;
-        db.read_id(key)
-            .map_err(|e| ProverError::StoreError(e.to_string()))
-    }
-}
-
-#[async_trait::async_trait]
-impl IdWrite for RedisTaskManager {
-    async fn store_id(&mut self, key: ProofKey, id: String) -> ProverResult<()> {
-        let mut db = self.arc_task_db.lock().await;
-        db.store_id(key, id)
-            .map_err(|e| ProverError::StoreError(e.to_string()))
-    }
-
-    async fn remove_id(&mut self, key: ProofKey) -> ProverResult<()> {
-        let mut db = self.arc_task_db.lock().await;
-        db.remove_id(key)
-            .map_err(|e| ProverError::StoreError(e.to_string()))
-    }
-}
-
-#[async_trait::async_trait]
-impl TaskManager for RedisTaskManager {
-    fn new(opts: &TaskManagerOpts) -> Self {
-        static INIT: Once = Once::new();
-        static mut REDIS_DB: Option<Arc<Mutex<RedisTaskDb>>> = None;
-        INIT.call_once(|| {
-            unsafe {
-                REDIS_DB = Some(Arc::new(Mutex::new({
-                    let db = RedisTaskDb::new(RedisConfig {
-                        url: opts.redis_url.clone(),
-                        ttl: opts.redis_ttl.clone(),
-                    })
-                    .unwrap();
-                    db
-                })))
-            };
-        });
-        Self {
-            arc_task_db: unsafe { REDIS_DB.clone().unwrap() },
-        }
-    }
-
-    async fn enqueue_task(
-        &mut self,
-        params: &ProofTaskDescriptor,
-    ) -> Result<TaskProvingStatusRecords, TaskManagerError> {
-        let mut task_db = self.arc_task_db.lock().await;
-        let enq_status = task_db.enqueue_task(params)?;
-        Ok(TaskProvingStatusRecords(vec![enq_status]))
-    }
-
-    async fn update_task_progress(
-        &mut self,
-        key: ProofTaskDescriptor,
-        status: TaskStatus,
-        proof: Option<&[u8]>,
-    ) -> TaskManagerResult<()> {
-        let mut task_db = self.arc_task_db.lock().await;
-        task_db.update_task_progress(key, status, proof)?;
-        Ok(())
-    }
-
-    /// Returns the latest triplet (submitter or fulfiller, status, last update time)
-    async fn get_task_proving_status(
-        &mut self,
-        key: &ProofTaskDescriptor,
-    ) -> TaskManagerResult<TaskProvingStatusRecords> {
-        let mut task_db = self.arc_task_db.lock().await;
-        match task_db.get_task_proving_status(key) {
-            Ok(records) => Ok(records),
-            Err(RedisDbError::KeyNotFound(_)) => Ok(TaskProvingStatusRecords(vec![])),
-            Err(e) => Err(TaskManagerError::RedisError(e)),
-        }
-    }
-
-    async fn get_task_proof(&mut self, key: &ProofTaskDescriptor) -> TaskManagerResult<Vec<u8>> {
-        let mut task_db = self.arc_task_db.lock().await;
-        let proof = task_db.get_task_proof(key)?;
-        Ok(proof)
-    }
-
-    /// Returns the total and detailed database size
-    async fn get_db_size(&mut self) -> TaskManagerResult<(usize, Vec<(String, usize)>)> {
-        let task_db = self.arc_task_db.lock().await;
-        let res = task_db.get_db_size()?;
-        Ok(res)
-    }
-
-    async fn prune_db(&mut self) -> TaskManagerResult<()> {
-        let mut task_db = self.arc_task_db.lock().await;
-        task_db.prune().map_err(TaskManagerError::RedisError)
-    }
-
-    async fn list_all_tasks(&mut self) -> TaskManagerResult<Vec<TaskReport>> {
-        let mut task_db = self.arc_task_db.lock().await;
-        task_db
-            .list_all_tasks()
-            .map_err(TaskManagerError::RedisError)
-    }
-
-    async fn list_stored_ids(&mut self) -> TaskManagerResult<Vec<(ProofKey, String)>> {
-        let mut task_db = self.arc_task_db.lock().await;
-        task_db
-            .list_stored_ids()
-            .map_err(TaskManagerError::RedisError)
-    }
-
-    async fn enqueue_aggregation_task(
-        &mut self,
-        request: &AggregationOnlyRequest,
-    ) -> TaskManagerResult<()> {
-        let mut task_db = self.arc_task_db.lock().await;
-        task_db
-            .enqueue_aggregation_task(request)
-            .map_err(TaskManagerError::RedisError)
-    }
-
-    async fn get_aggregation_task_proving_status(
-        &mut self,
-        request: &AggregationOnlyRequest,
-    ) -> TaskManagerResult<TaskProvingStatusRecords> {
-        let mut task_db = self.arc_task_db.lock().await;
-        match task_db.get_aggregation_task_proving_status(request) {
-            Ok(records) => Ok(records),
-            Err(RedisDbError::KeyNotFound(_)) => Ok(TaskProvingStatusRecords(vec![])),
-            Err(e) => Err(TaskManagerError::RedisError(e)),
-        }
-    }
-
-    async fn update_aggregation_task_progress(
-        &mut self,
-        request: &AggregationOnlyRequest,
-        status: TaskStatus,
-        proof: Option<&[u8]>,
-    ) -> TaskManagerResult<()> {
-        let mut task_db = self.arc_task_db.lock().await;
-        task_db
-            .update_aggregation_task_progress(request, status, proof)
-            .map_err(TaskManagerError::RedisError)
-    }
-
-    async fn get_aggregation_task_proof(
-        &mut self,
-        request: &AggregationOnlyRequest,
-    ) -> TaskManagerResult<Vec<u8>> {
-        let mut task_db = self.arc_task_db.lock().await;
-        task_db
-            .get_aggregation_task_proof(request)
-            .map_err(TaskManagerError::RedisError)
-    }
-
-    async fn prune_aggregation_db(&mut self) -> TaskManagerResult<()> {
-        let mut task_db = self.arc_task_db.lock().await;
-        task_db
-            .prune_aggregation()
-            .map_err(TaskManagerError::RedisError)
-    }
-
-    async fn list_all_aggregation_tasks(
-        &mut self,
-    ) -> TaskManagerResult<Vec<AggregationTaskReport>> {
-        let mut task_db = self.arc_task_db.lock().await;
-        task_db
-            .list_all_aggregation_tasks()
-            .map_err(TaskManagerError::RedisError)
-    }
-}