diff --git a/.github/workflows/ci-build-test-reusable.yml b/.github/workflows/ci-build-test-reusable.yml index 40c86e43b..443fddef1 100644 --- a/.github/workflows/ci-build-test-reusable.yml +++ b/.github/workflows/ci-build-test-reusable.yml @@ -57,5 +57,7 @@ jobs: - name: Build ${{ inputs.version_name }} prover run: make build + - name: Setup Redis service + uses: shogo82148/actions-setup-redis@v1 - name: Test ${{ inputs.version_name }} prover run: make test diff --git a/.github/workflows/ci-native.yml b/.github/workflows/ci-native.yml index a63cc3256..0a0436429 100644 --- a/.github/workflows/ci-native.yml +++ b/.github/workflows/ci-native.yml @@ -17,17 +17,6 @@ jobs: with: fetch-depth: 0 - - name: Check if specific file changed - id: check_file - run: | - BASE_BRANCH=${{ github.event.pull_request.base.ref }} - if git diff --name-only origin/$BASE_BRANCH ${{ github.sha }} | grep -q "taskdb/src/redis_db.rs"; then - echo "redis changed" - echo "::set-output name=taskdb::raiko-tasks/redis-db" - else - echo "redis unchanged" - echo "::set-output name=taskdb::" - fi build-test-native: name: Build and test native diff --git a/Cargo.lock b/Cargo.lock index 8573722a0..5eded59c3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6108,9 +6108,11 @@ dependencies = [ "serde", "serde_json", "serde_with 3.11.0", + "serial_test", "sgx-prover", "sha2", "sp1-driver", + "test-log", "thiserror", "tokio", "tokio-util", @@ -6139,6 +6141,7 @@ dependencies = [ "anyhow", "async-trait", "bincode", + "bytemuck", "cfg-if", "chrono", "flate2", @@ -8389,9 +8392,9 @@ dependencies = [ [[package]] name = "serial_test" -version = "3.1.1" +version = "3.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b4b487fe2acf240a021cf57c6b2b4903b1e78ca0ecd862a71b71d2a51fed77d" +checksum = "1b258109f244e1d6891bf1053a55d63a5cd4f8f4c30cf9a1280989f80e7a1fa9" dependencies = [ "futures", "log", @@ -8403,9 +8406,9 @@ dependencies = [ [[package]] name = "serial_test_derive" -version = "3.1.1" +version = "3.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "82fe9db325bcef1fbcde82e078a5cc4efdf787e96b3b9cf45b50b529f2083d67" +checksum = "5d69265a08751de7844521fd15003ae0a888e035773ba05695c5c759a6f89eef" dependencies = [ "proc-macro2", "quote", @@ -9377,6 +9380,28 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76" +[[package]] +name = "test-log" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3dffced63c2b5c7be278154d76b479f9f9920ed34e7574201407f0b14e2bbb93" +dependencies = [ + "env_logger", + "test-log-macros", + "tracing-subscriber 0.3.18", +] + +[[package]] +name = "test-log-macros" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5999e24eaa32083191ba4e425deb75cdf25efefabe5aaccb7446dd0d4122a3f5" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.82", +] + [[package]] name = "textwrap" version = "0.11.0" diff --git a/Cargo.toml b/Cargo.toml index 0aad4880e..3ad2f7ecb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -185,6 +185,8 @@ dirs = "5.0.1" pathdiff = "0.2.1" dotenv = "0.15.0" backoff = "0.4.0" +test-log = "0.2.16" +serial_test = "3.2.0" [patch.crates-io] revm = { git = "https://github.com/taikoxyz/revm.git", branch = "v36-taiko" } diff --git a/core/src/interfaces.rs b/core/src/interfaces.rs index 03382f9a0..4c704848c 100644 --- a/core/src/interfaces.rs +++ b/core/src/interfaces.rs @@ -79,6 +79,58 @@ impl From for RaikoError { pub type RaikoResult = Result; +/// Get the proving image for a given proof type. +pub fn get_proving_image(proof_type: ProofType) -> RaikoResult<(&'static [u8], &'static [u32; 8])> { + match proof_type { + ProofType::Native => Ok(NativeProver::current_proving_image()), + ProofType::Sp1 => { + #[cfg(feature = "sp1")] + return Ok(sp1_driver::Sp1Prover::current_proving_image()); + #[cfg(not(feature = "sp1"))] + Err(RaikoError::FeatureNotSupportedError(proof_type)) + } + ProofType::Risc0 => { + #[cfg(feature = "risc0")] + return Ok(risc0_driver::Risc0Prover::current_proving_image()); + #[cfg(not(feature = "risc0"))] + Err(RaikoError::FeatureNotSupportedError(proof_type)) + } + ProofType::Sgx => { + #[cfg(feature = "sgx")] + return Ok(sgx_prover::SgxProver::current_proving_image()); + #[cfg(not(feature = "sgx"))] + Err(RaikoError::FeatureNotSupportedError(proof_type)) + } + } +} + +/// Get the aggregation image for a given proof type. +pub fn get_aggregation_image( + proof_type: ProofType, +) -> RaikoResult<(&'static [u8], &'static [u32; 8])> { + match proof_type { + ProofType::Native => Ok(NativeProver::current_aggregation_image()), + ProofType::Sp1 => { + #[cfg(feature = "sp1")] + return Ok(sp1_driver::Sp1Prover::current_aggregation_image()); + #[cfg(not(feature = "sp1"))] + Err(RaikoError::FeatureNotSupportedError(proof_type)) + } + ProofType::Risc0 => { + #[cfg(feature = "risc0")] + return Ok(risc0_driver::Risc0Prover::current_aggregation_image()); + #[cfg(not(feature = "risc0"))] + Err(RaikoError::FeatureNotSupportedError(proof_type)) + } + ProofType::Sgx => { + #[cfg(feature = "sgx")] + return Ok(sgx_prover::SgxProver::current_aggregation_image()); + #[cfg(not(feature = "sgx"))] + Err(RaikoError::FeatureNotSupportedError(proof_type)) + } + } +} + /// Run the prover driver depending on the proof type. pub async fn run_prover( proof_type: ProofType, @@ -217,6 +269,8 @@ pub struct ProofRequest { pub proof_type: ProofType, /// Blob proof type. pub blob_proof_type: BlobProofType, + /// The guest image id for RISC0/SP1 provers. + pub image_id: Option, #[serde(flatten)] /// Additional prover params. pub prover_args: HashMap, @@ -251,6 +305,9 @@ pub struct ProofRequestOpt { pub proof_type: Option, /// Blob proof type. pub blob_proof_type: Option, + #[arg(long, require_equals = true)] + /// The guest image id for RISC0/SP1 provers. + pub image_id: Option, #[command(flatten)] #[serde(flatten)] /// Any additional prover params in JSON format. @@ -310,6 +367,24 @@ impl TryFrom for ProofRequest { type Error = RaikoError; fn try_from(value: ProofRequestOpt) -> Result { + let proof_type = value + .proof_type + .as_ref() + .ok_or(RaikoError::InvalidRequestConfig( + "Missing proof_type".to_string(), + ))? + .parse() + .map_err(|_| RaikoError::InvalidRequestConfig("Invalid proof_type".to_string()))?; + + // Check if we need an image ID for this proof type + let image_id = match &proof_type { + ProofType::Risc0 | ProofType::Sp1 => value + .image_id + .clone() + .ok_or_else(|| RaikoError::InvalidRequestConfig("Missing image_id".to_string()))?, + _ => value.image_id.unwrap_or_default(), + }; + Ok(Self { block_number: value.block_number.ok_or(RaikoError::InvalidRequestConfig( "Missing block number".to_string(), @@ -335,13 +410,7 @@ impl TryFrom for ProofRequest { ))? .parse() .map_err(|_| RaikoError::InvalidRequestConfig("Invalid prover".to_string()))?, - proof_type: value - .proof_type - .ok_or(RaikoError::InvalidRequestConfig( - "Missing proof_type".to_string(), - ))? - .parse() - .map_err(|_| RaikoError::InvalidRequestConfig("Invalid proof_type".to_string()))?, + proof_type, blob_proof_type: value .blob_proof_type .unwrap_or("proof_of_equivalence".to_string()) @@ -349,6 +418,7 @@ impl TryFrom for ProofRequest { .map_err(|_| { RaikoError::InvalidRequestConfig("Invalid blob_proof_type".to_string()) })?, + image_id: Some(image_id), prover_args: value.prover_args.into(), }) } @@ -372,6 +442,8 @@ pub struct AggregationRequest { pub proof_type: Option, /// Blob proof type. pub blob_proof_type: Option, + /// The guest image id for RISC0/SP1 provers. + pub image_id: Option, #[serde(flatten)] /// Any additional prover params in JSON format. pub prover_args: ProverSpecificOpts, @@ -403,6 +475,7 @@ impl From for Vec { prover: value.prover.clone(), proof_type: value.proof_type.clone(), blob_proof_type: value.blob_proof_type.clone(), + image_id: value.image_id.clone(), prover_args: value.prover_args.clone(), }, ) @@ -426,11 +499,24 @@ impl From for AggregationRequest { prover: value.prover, proof_type: value.proof_type, blob_proof_type: value.blob_proof_type, + image_id: value.image_id, prover_args: value.prover_args, } } } +impl From<(AggregationRequest, Vec)> for AggregationOnlyRequest { + fn from((request, proofs): (AggregationRequest, Vec)) -> Self { + Self { + proofs, + aggregation_ids: request.block_numbers.iter().map(|(id, _)| *id).collect(), + proof_type: request.proof_type, + image_id: request.image_id, + prover_args: request.prover_args, + } + } +} + #[derive(Default, Clone, Serialize, Deserialize, Debug, ToSchema, PartialEq, Eq, Hash)] #[serde(default)] /// A request for proof aggregation of multiple proofs. @@ -441,6 +527,7 @@ pub struct AggregationOnlyRequest { pub proofs: Vec, /// The proof type. pub proof_type: Option, + pub image_id: Option, #[serde(flatten)] /// Any additional prover params in JSON format. pub prover_args: ProverSpecificOpts, @@ -455,17 +542,6 @@ impl Display for AggregationOnlyRequest { } } -impl From<(AggregationRequest, Vec)> for AggregationOnlyRequest { - fn from((request, proofs): (AggregationRequest, Vec)) -> Self { - Self { - proofs, - aggregation_ids: request.block_numbers.iter().map(|(id, _)| *id).collect(), - proof_type: request.proof_type, - prover_args: request.prover_args, - } - } -} - impl AggregationOnlyRequest { /// Merge proof request options into aggregation request options. pub fn merge(&mut self, opts: &ProofRequestOpt) -> RaikoResult<()> { diff --git a/core/src/lib.rs b/core/src/lib.rs index fa7adc02d..f64d2106b 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -329,6 +329,7 @@ mod tests { proof_type, blob_proof_type: BlobProofType::ProofOfEquivalence, prover_args: test_proof_params(false), + image_id: None, }; prove_block(l1_chain_spec, taiko_chain_spec, proof_request).await; } @@ -358,6 +359,7 @@ mod tests { proof_type, blob_proof_type: BlobProofType::ProofOfEquivalence, prover_args: test_proof_params(false), + image_id: None, }; prove_block(l1_chain_spec, taiko_chain_spec, proof_request).await; } @@ -397,6 +399,7 @@ mod tests { proof_type, blob_proof_type: BlobProofType::ProofOfEquivalence, prover_args: test_proof_params(false), + image_id: None, }; prove_block(l1_chain_spec, taiko_chain_spec, proof_request).await; } @@ -430,6 +433,7 @@ mod tests { proof_type, blob_proof_type: BlobProofType::ProofOfEquivalence, prover_args: test_proof_params(false), + image_id: None, }; prove_block(l1_chain_spec, taiko_chain_spec, proof_request).await; } @@ -460,6 +464,7 @@ mod tests { proof_type, blob_proof_type: BlobProofType::ProofOfEquivalence, prover_args: test_proof_params(true), + image_id: None, }; let proof = prove_block(l1_chain_spec, taiko_chain_spec, proof_request).await; diff --git a/host/Cargo.toml b/host/Cargo.toml index e6cec5ff8..4203562ea 100644 --- a/host/Cargo.toml +++ b/host/Cargo.toml @@ -82,6 +82,8 @@ assert_cmd = { workspace = true } rstest = { workspace = true } ethers-core = { workspace = true } rand = { workspace = true } +test-log = { workspace = true } +serial_test = { workspace = true } [features] default = [] diff --git a/host/src/cache.rs b/host/src/cache.rs index 606a7f4a1..24bf07752 100644 --- a/host/src/cache.rs +++ b/host/src/cache.rs @@ -108,6 +108,7 @@ mod test { blob_proof_type: BlobProofType::KzgVersionedHash, prover_args: Default::default(), l1_inclusion_block_number: 0, + image_id: None, }; let raiko = Raiko::new( l1_chain_spec.clone(), @@ -137,7 +138,8 @@ mod test { provider.provider.get_block_number().await.unwrap() } - #[tokio::test] + #[serial_test::serial] + #[test_log::test(tokio::test)] async fn test_generate_input_from_cache() { let l1 = &Network::Holesky.to_string(); let l2 = &Network::TaikoA7.to_string(); diff --git a/host/src/lib.rs b/host/src/lib.rs index 930bc0d06..2759c7625 100644 --- a/host/src/lib.rs +++ b/host/src/lib.rs @@ -10,7 +10,8 @@ use raiko_core::{ merge, }; use raiko_lib::consts::SupportedChainSpecs; -use raiko_tasks::{get_task_manager, ProofTaskDescriptor, TaskManagerOpts, TaskManagerWrapperImpl}; +use raiko_tasks::TaskManager; +use raiko_tasks::{ProofTaskDescriptor, TaskManagerOpts, TaskManagerWrapperImpl}; use serde::{Deserialize, Serialize}; use serde_json::Value; use tokio::sync::mpsc; @@ -154,6 +155,8 @@ pub struct ProverState { pub chain_specs: SupportedChainSpecs, pub task_channel: mpsc::Sender, pause_flag: Arc, + + task_manager: TaskManagerWrapperImpl, } #[derive(Debug)] @@ -198,9 +201,12 @@ impl ProverState { let opts_clone = opts.clone(); let chain_specs_clone = chain_specs.clone(); let sender = task_channel.clone(); + let task_manager = TaskManagerWrapperImpl::new(&opts.clone().into()); + let task_manager_clone = task_manager.clone(); + tokio::spawn(async move { - ProofActor::new(sender, receiver, opts_clone, chain_specs_clone) - .run() + ProofActor::new(sender, opts_clone, chain_specs_clone, task_manager_clone) + .run(receiver) .await; }); @@ -209,11 +215,12 @@ impl ProverState { chain_specs, task_channel, pause_flag, + task_manager, }) } pub fn task_manager(&self) -> TaskManagerWrapperImpl { - get_task_manager(&(&self.opts).into()) + self.task_manager.clone() } pub fn request_config(&self) -> ProofRequestOpt { diff --git a/host/src/proof.rs b/host/src/proof.rs index d01de70ef..b134882e7 100644 --- a/host/src/proof.rs +++ b/host/src/proof.rs @@ -18,9 +18,7 @@ use raiko_lib::{ prover::{IdWrite, Proof}, Measurement, }; -use raiko_tasks::{ - get_task_manager, ProofTaskDescriptor, TaskManager, TaskManagerWrapperImpl, TaskStatus, -}; +use raiko_tasks::{ProofTaskDescriptor, TaskManager, TaskManagerWrapperImpl, TaskStatus}; use reth_primitives::B256; use tokio::{ select, @@ -43,22 +41,23 @@ use crate::{ Message, Opts, }; +#[derive(Clone)] pub struct ProofActor { opts: Opts, chain_specs: SupportedChainSpecs, aggregate_tasks: Arc>>, running_tasks: Arc>>, pending_tasks: Arc>>, - receiver: Receiver, sender: Sender, + task_manager: TaskManagerWrapperImpl, } impl ProofActor { pub fn new( sender: Sender, - receiver: Receiver, opts: Opts, chain_specs: SupportedChainSpecs, + task_manager: TaskManagerWrapperImpl, ) -> Self { let running_tasks = Arc::new(Mutex::new( HashMap::::new(), @@ -75,11 +74,15 @@ impl ProofActor { aggregate_tasks, running_tasks, pending_tasks, - receiver, sender, + task_manager, } } + pub fn task_manager(&self) -> TaskManagerWrapperImpl { + self.task_manager.clone() + } + pub async fn cancel_task(&mut self, key: ProofTaskDescriptor) -> HostResult<()> { let task = { let tasks_map = self.running_tasks.lock().await; @@ -92,7 +95,7 @@ impl ProofActor { } }; - let mut manager = get_task_manager(&self.opts.clone().into()); + let mut manager = self.task_manager(); cancel_proof( key.proof_system, ( @@ -133,13 +136,14 @@ impl ProofActor { } }; - let key = ProofTaskDescriptor::from(( + let key = ProofTaskDescriptor::new( chain_id, proof_request.block_number, blockhash, proof_request.proof_type, - proof_request.prover.clone().to_string(), - )); + proof_request.prover.to_string(), + proof_request.image_id.clone(), + ); { let mut tasks = self.running_tasks.lock().await; @@ -151,12 +155,13 @@ impl ProofActor { let opts = self.opts.clone(); let chain_specs = self.chain_specs.clone(); + let proof_actor = self.clone(); tokio::spawn(async move { select! { _ = cancel_token.cancelled() => { info!("Task cancelled"); } - result = Self::handle_message(proof_request.clone(), key.clone(), &opts, &chain_specs) => { + result = proof_actor.handle_message(proof_request.clone(), key.clone(), &opts, &chain_specs) => { match result { Ok(status) => { info!("Host handling message: {status:?}"); @@ -188,7 +193,7 @@ impl ProofActor { }; // TODO:(petar) implement cancel_proof_aggregation - // let mut manager = get_task_manager(&self.opts.clone().into()); + // let mut manager = self.task_manager(); // let proof_type = ProofType::from_str( // request // .proof_type @@ -224,12 +229,13 @@ impl ProofActor { let tasks = self.aggregate_tasks.clone(); let opts = self.opts.clone(); + let proof_actor = self.clone(); tokio::spawn(async move { select! { _ = cancel_token.cancelled() => { info!("Task cancelled"); } - result = Self::handle_aggregate(request_clone, &opts) => { + result = proof_actor.handle_aggregate(request_clone, &opts) => { match result { Ok(status) => { info!("Host handling message: {status:?}"); @@ -245,10 +251,10 @@ impl ProofActor { }); } - pub async fn run(&mut self) { + pub async fn run(&mut self, mut receiver: Receiver) { // recv() is protected by outside mpsc, no lock needed here let semaphore = Arc::new(Semaphore::new(self.opts.concurrency_limit)); - while let Some(message) = self.receiver.recv().await { + while let Some(message) = receiver.recv().await { match message { Message::Cancel(key) => { debug!("Message::Cancel({key:?})"); @@ -311,12 +317,13 @@ impl ProofActor { } pub async fn handle_message( + &self, proof_request: ProofRequest, key: ProofTaskDescriptor, opts: &Opts, chain_specs: &SupportedChainSpecs, ) -> HostResult { - let mut manager = get_task_manager(&opts.clone().into()); + let mut manager = self.task_manager(); let status = manager.get_task_proving_status(&key).await?; @@ -346,11 +353,15 @@ impl ProofActor { Ok(status) } - pub async fn handle_aggregate(request: AggregationOnlyRequest, opts: &Opts) -> HostResult<()> { + pub async fn handle_aggregate( + &self, + request: AggregationOnlyRequest, + _opts: &Opts, + ) -> HostResult<()> { let proof_type_str = request.proof_type.to_owned().unwrap_or_default(); let proof_type = ProofType::from_str(&proof_type_str).map_err(HostError::Conversion)?; - let mut manager = get_task_manager(&opts.clone().into()); + let mut manager = self.task_manager(); let status = manager .get_aggregation_task_proving_status(&request) @@ -371,7 +382,7 @@ impl ProofActor { }; let output = AggregationGuestOutput { hash: B256::ZERO }; let config = serde_json::to_value(request.clone().prover_args)?; - let mut manager = get_task_manager(&opts.clone().into()); + let mut manager = self.task_manager(); let (status, proof) = match aggregate_proofs(proof_type, input, &output, &config, Some(&mut manager)).await { @@ -575,9 +586,38 @@ pub async fn handle_proof( #[cfg(test)] mod tests { use super::*; + use alloy_primitives::ChainId; + use rand::Rng; use tokio::sync::mpsc; - #[tokio::test] + fn create_test_proof_request() -> ProofRequest { + ProofRequest { + block_number: 1, + l1_inclusion_block_number: 1, + network: "test".to_string(), + l1_network: "test".to_string(), + graffiti: B256::ZERO, + prover: Default::default(), + proof_type: Default::default(), + blob_proof_type: Default::default(), + image_id: None, + prover_args: HashMap::new(), + } + } + + fn create_test_task_descriptor() -> ProofTaskDescriptor { + ProofTaskDescriptor::new( + ChainId::from(1u64), + 1, + B256::default(), + ProofType::Native, + "test".to_string(), + None, + ) + } + + #[serial_test::serial] + #[test_log::test(tokio::test)] async fn test_handle_system_pause_happy_path() { let (tx, rx) = mpsc::channel(100); let mut actor = setup_actor_with_tasks(tx, rx); @@ -586,23 +626,18 @@ mod tests { assert!(result.is_ok()); } - #[tokio::test] + #[serial_test::serial] + #[test_log::test(tokio::test)] async fn test_handle_system_pause_with_pending_tasks() { let (tx, rx) = mpsc::channel(100); let mut actor = setup_actor_with_tasks(tx, rx); // Add some pending tasks - actor.pending_tasks.lock().await.push_back(ProofRequest { - block_number: 1, - l1_inclusion_block_number: 1, - network: "test".to_string(), - l1_network: "test".to_string(), - graffiti: B256::ZERO, - prover: Default::default(), - proof_type: Default::default(), - blob_proof_type: Default::default(), - prover_args: HashMap::new(), - }); + actor + .pending_tasks + .lock() + .await + .push_back(create_test_proof_request()); let result = actor.handle_system_pause().await; assert!(result.is_ok()); @@ -611,13 +646,14 @@ mod tests { assert_eq!(actor.pending_tasks.lock().await.len(), 0); } - #[tokio::test] + #[serial_test::serial] + #[test_log::test(tokio::test)] async fn test_handle_system_pause_with_running_tasks() { let (tx, rx) = mpsc::channel(100); let mut actor = setup_actor_with_tasks(tx, rx); // Add some running tasks - let task_descriptor = ProofTaskDescriptor::default(); + let task_descriptor = create_test_task_descriptor(); let cancellation_token = CancellationToken::new(); actor .running_tasks @@ -635,52 +671,22 @@ mod tests { // assert_eq!(actor.running_tasks.lock().await.len(), 0); } - #[tokio::test] - async fn test_handle_system_pause_with_aggregation_tasks() { + #[serial_test::serial] + #[test_log::test(tokio::test)] + async fn test_handle_system_pause_with_failures() { let (tx, rx) = mpsc::channel(100); let mut actor = setup_actor_with_tasks(tx, rx); - // Add some aggregation tasks - let request = AggregationOnlyRequest::default(); - let cancellation_token = CancellationToken::new(); + // Add some pending tasks actor - .aggregate_tasks + .pending_tasks .lock() .await - .insert(request.clone(), cancellation_token.clone()); - - let result = actor.handle_system_pause().await; - assert!(result.is_ok()); - - // Verify aggregation tasks were cancelled - assert!(cancellation_token.is_cancelled()); - // TODO(Kero): Cancelled tasks should be removed from aggregate_tasks - // assert_eq!(actor.aggregate_tasks.lock().await.len(), 0); - } - - #[tokio::test] - async fn test_handle_system_pause_with_failures() { - let (tx, rx) = mpsc::channel(100); - let mut actor = setup_actor_with_tasks(tx, rx); - - // Add some pending tasks - { - actor.pending_tasks.lock().await.push_back(ProofRequest { - block_number: 1, - l1_inclusion_block_number: 1, - network: "test".to_string(), - l1_network: "test".to_string(), - graffiti: B256::ZERO, - prover: Default::default(), - proof_type: Default::default(), - blob_proof_type: Default::default(), - prover_args: HashMap::new(), - }); - } + .push_back(create_test_proof_request()); let good_running_task_token = { // Add some running tasks - let task_descriptor = ProofTaskDescriptor::default(); + let task_descriptor = create_test_task_descriptor(); let cancellation_token = CancellationToken::new(); actor .running_tasks @@ -727,12 +733,16 @@ mod tests { } // Helper function to setup actor with common test configuration - fn setup_actor_with_tasks(tx: Sender, rx: Receiver) -> ProofActor { + fn setup_actor_with_tasks(tx: Sender, _rx: Receiver) -> ProofActor { + let redis_database = rand::thread_rng().gen_range(0..10000); let opts = Opts { concurrency_limit: 4, + redis_url: format!("redis://localhost:6379/{redis_database}"), + redis_ttl: 3600, ..Default::default() }; + let task_manager = TaskManagerWrapperImpl::new(&opts.clone().into()); - ProofActor::new(tx, rx, opts, SupportedChainSpecs::default()) + ProofActor::new(tx, opts, SupportedChainSpecs::default(), task_manager) } } diff --git a/host/src/server/api/admin.rs b/host/src/server/api/admin.rs index 948e59f70..281e2f2cd 100644 --- a/host/src/server/api/admin.rs +++ b/host/src/server/api/admin.rs @@ -29,7 +29,8 @@ mod tests { use std::path::PathBuf; use tower::ServiceExt; - #[tokio::test] + #[serial_test::serial] + #[test_log::test(tokio::test)] async fn test_pause() { let opts = { let mut opts = crate::Opts::parse(); @@ -54,7 +55,8 @@ mod tests { assert!(state.is_paused()); } - #[tokio::test] + #[serial_test::serial] + #[test_log::test(tokio::test)] async fn test_pause_when_already_paused() { let opts = { let mut opts = crate::Opts::parse(); @@ -82,7 +84,8 @@ mod tests { assert!(state.is_paused()); } - #[tokio::test] + #[serial_test::serial] + #[test_log::test(tokio::test)] async fn test_unpause() { let opts = { let mut opts = crate::Opts::parse(); diff --git a/host/src/server/api/util.rs b/host/src/server/api/util.rs index d47c1da24..f23cd7ee2 100644 --- a/host/src/server/api/util.rs +++ b/host/src/server/api/util.rs @@ -1,3 +1,8 @@ +use raiko_core::interfaces::{ + get_aggregation_image, get_proving_image, AggregationRequest, ProofRequestOpt, +}; +use raiko_lib::{proof_type::ProofType, prover::encode_image_id}; + use crate::{ interfaces::{HostError, HostResult}, ProverState, @@ -10,3 +15,106 @@ pub fn ensure_not_paused(prover_state: &ProverState) -> HostResult<()> { } Ok(()) } + +/// Ensure the image_id is filled for RISC0/SP1, and not filled for Native/SGX. +/// And fill it with the default value for RISC0/SP1 proof type. +pub fn ensure_proof_request_image_id(proof_request_opt: &mut ProofRequestOpt) -> HostResult<()> { + // Parse the proof type string + let proof_type = proof_request_opt + .proof_type + .as_ref() + .ok_or(HostError::InvalidRequestConfig( + "Missing proof_type".to_string(), + ))? + .parse() + .map_err(|_| HostError::InvalidRequestConfig("Invalid proof_type".to_string()))?; + match proof_type { + // For Native/SGX, ensure image_id is None + ProofType::Native | ProofType::Sgx => { + if proof_request_opt.image_id.is_some() { + return Err(HostError::InvalidRequestConfig( + "Native/SGX provers must not have image_id".to_string(), + )); + } + } + // For RISC0/SP1, fill default image_id if None + ProofType::Risc0 | ProofType::Sp1 => { + match &proof_request_opt.image_id { + Some(image_id) => { + // Temporarily workaround for RISC0/SP1 proof type: assert that the image_id is the same with `get_aggregation_image_id()`, + // that means we don't support custom image_id for RISC0/SP1 proof type. + let (_, supported_image_id) = get_proving_image(proof_type)?; + let supported_image_id = encode_image_id(supported_image_id); + if *image_id != supported_image_id { + return Err(HostError::InvalidRequestConfig( + format!( + "Custom image_id is not supported for RISC0/SP1 proof type: actual=({}) != supported=({})", + image_id, supported_image_id + ), + )); + } + } + None => { + // If image_id is None, fill it with the default value + let (_, supported_image_id) = get_proving_image(proof_type)?; + let supported_image_id = encode_image_id(supported_image_id); + proof_request_opt.image_id = Some(supported_image_id); + } + } + } + } + Ok(()) +} + +/// Ensure the image_id is filled for RISC0/SP1, and not filled for Native/SGX. +/// And fill it with the default value for RISC0/SP1 proof type. +pub fn ensure_aggregation_request_image_id( + aggregation_request: &mut AggregationRequest, +) -> HostResult<()> { + // Parse the proof type string + let proof_type = aggregation_request + .proof_type + .as_ref() + .ok_or(HostError::InvalidRequestConfig( + "Missing proof_type".to_string(), + ))? + .parse() + .map_err(|_| HostError::InvalidRequestConfig("Invalid proof_type".to_string()))?; + + match proof_type { + // For Native/SGX, ensure image_id is None + ProofType::Native | ProofType::Sgx => { + if aggregation_request.image_id.is_some() { + return Err(HostError::InvalidRequestConfig( + "Native/SGX provers must not have image_id".to_string(), + )); + } + } + // For RISC0/SP1, fill default image_id if None + ProofType::Risc0 | ProofType::Sp1 => { + match &aggregation_request.image_id { + Some(image_id) => { + // Temporarily workaround for RISC0/SP1 proof type: assert that the image_id is the same with `get_aggregation_image_id()`, + // that means we don't support custom image_id for RISC0/SP1 proof type. + let (_, supported_image_id) = get_aggregation_image(proof_type)?; + let supported_image_id = encode_image_id(supported_image_id); + if *image_id != supported_image_id { + return Err(HostError::InvalidRequestConfig( + format!( + "Custom image_id is not supported for RISC0/SP1 proof type: actual=({}) != supported=({})", + image_id, supported_image_id + ), + )); + } + } + None => { + // If image_id is None, fill it with the default value + let (_, supported_image_id) = get_aggregation_image(proof_type)?; + let supported_image_id = encode_image_id(supported_image_id); + aggregation_request.image_id = Some(supported_image_id); + } + } + } + } + Ok(()) +} diff --git a/host/src/server/api/v1/proof.rs b/host/src/server/api/v1/proof.rs index 1437a2a06..582596d89 100644 --- a/host/src/server/api/v1/proof.rs +++ b/host/src/server/api/v1/proof.rs @@ -1,6 +1,5 @@ use axum::{debug_handler, extract::State, routing::post, Json, Router}; use raiko_core::interfaces::ProofRequest; -use raiko_tasks::get_task_manager; use serde_json::Value; use utoipa::OpenApi; @@ -8,7 +7,10 @@ use crate::{ interfaces::HostResult, metrics::{dec_current_req, inc_current_req, inc_guest_req_count, inc_host_req_count}, proof::handle_proof, - server::api::{util::ensure_not_paused, v1::Status}, + server::api::{ + util::{ensure_not_paused, ensure_proof_request_image_id}, + v1::Status, + }, ProverState, }; @@ -43,12 +45,14 @@ async fn proof_handler( let mut config = prover_state.request_config(); config.merge(&req)?; + ensure_proof_request_image_id(&mut config)?; + // Construct the actual proof request from the available configs. let proof_request = ProofRequest::try_from(config)?; inc_host_req_count(proof_request.block_number); inc_guest_req_count(&proof_request.proof_type, proof_request.block_number); - let mut manager = get_task_manager(&prover_state.opts.clone().into()); + let mut manager = prover_state.task_manager(); handle_proof( &proof_request, diff --git a/host/src/server/api/v2/proof/cancel.rs b/host/src/server/api/v2/proof/cancel.rs index a941df4da..119f3351d 100644 --- a/host/src/server/api/v2/proof/cancel.rs +++ b/host/src/server/api/v2/proof/cancel.rs @@ -4,7 +4,14 @@ use raiko_tasks::{ProofTaskDescriptor, TaskManager, TaskStatus}; use serde_json::Value; use utoipa::OpenApi; -use crate::{interfaces::HostResult, server::api::v2::CancelStatus, Message, ProverState}; +use crate::{ + interfaces::HostResult, + server::api::{ + util::{ensure_not_paused, ensure_proof_request_image_id}, + v2::CancelStatus, + }, + Message, ProverState, +}; #[utoipa::path(post, path = "/proof/cancel", tag = "Proving", @@ -31,6 +38,9 @@ async fn cancel_handler( let mut config = prover_state.request_config(); config.merge(&req)?; + ensure_not_paused(&prover_state)?; + ensure_proof_request_image_id(&mut config)?; + // Construct the actual proof request from the available configs. let proof_request = ProofRequest::try_from(config)?; @@ -41,13 +51,14 @@ async fn cancel_handler( ) .await?; - let key = ProofTaskDescriptor::from(( + let key = ProofTaskDescriptor::new( chain_id, proof_request.block_number, block_hash, proof_request.proof_type, - proof_request.prover.clone().to_string(), - )); + proof_request.prover.to_string(), + proof_request.image_id.clone(), + ); prover_state .task_channel diff --git a/host/src/server/api/v2/proof/mod.rs b/host/src/server/api/v2/proof/mod.rs index dfcc10e23..19e4f3a1d 100644 --- a/host/src/server/api/v2/proof/mod.rs +++ b/host/src/server/api/v2/proof/mod.rs @@ -7,7 +7,10 @@ use utoipa::OpenApi; use crate::{ interfaces::HostResult, metrics::{inc_current_req, inc_guest_req_count, inc_host_req_count}, - server::api::{util::ensure_not_paused, v2::Status}, + server::api::{ + util::{ensure_not_paused, ensure_proof_request_image_id}, + v2::Status, + }, Message, ProverState, }; @@ -45,6 +48,13 @@ async fn proof_handler( let mut config = prover_state.request_config(); config.merge(&req)?; + // TODO: remove this assert after we support custom image_id for RISC0/SP1 proof type + assert!( + config.image_id.is_none(), + "currently we don't support custom image_id for RISC0/SP1 proof type" + ); + ensure_proof_request_image_id(&mut config)?; + // Construct the actual proof request from the available configs. let proof_request = ProofRequest::try_from(config)?; inc_host_req_count(proof_request.block_number); @@ -57,16 +67,23 @@ async fn proof_handler( ) .await?; - let key = ProofTaskDescriptor::from(( + let key = ProofTaskDescriptor::new( chain_id, proof_request.block_number, blockhash, proof_request.proof_type, proof_request.prover.to_string(), - )); + proof_request.image_id.clone(), + ); let mut manager = prover_state.task_manager(); let status = manager.get_task_proving_status(&key).await?; + tracing::info!( + "/v2/proof, request: {:?}, status: {:?}", + proof_request, + status + ); + match status.0.last() { Some((latest_status, ..)) => { match latest_status { diff --git a/host/src/server/api/v3/proof/cancel.rs b/host/src/server/api/v3/proof/cancel.rs index 36dcb9764..8da33e883 100644 --- a/host/src/server/api/v3/proof/cancel.rs +++ b/host/src/server/api/v3/proof/cancel.rs @@ -6,7 +6,14 @@ use raiko_core::{ use raiko_tasks::{ProofTaskDescriptor, TaskManager, TaskStatus}; use utoipa::OpenApi; -use crate::{interfaces::HostResult, server::api::v2::CancelStatus, Message, ProverState}; +use crate::{ + interfaces::HostResult, + server::api::{ + util::{ensure_aggregation_request_image_id, ensure_not_paused}, + v2::CancelStatus, + }, + Message, ProverState, +}; #[utoipa::path(post, path = "/proof/cancel", tag = "Proving", @@ -32,6 +39,9 @@ async fn cancel_handler( // options with the request from the client. aggregation_request.merge(&prover_state.request_config())?; + ensure_not_paused(&prover_state)?; + ensure_aggregation_request_image_id(&mut aggregation_request)?; + let proof_request_opts: Vec = aggregation_request.into(); for opt in proof_request_opts { @@ -44,13 +54,14 @@ async fn cancel_handler( ) .await?; - let key = ProofTaskDescriptor::from(( + let key = ProofTaskDescriptor::new( chain_id, proof_request.block_number, block_hash, proof_request.proof_type, - proof_request.prover.clone().to_string(), - )); + proof_request.prover.to_string(), + proof_request.image_id.clone(), + ); prover_state .task_channel diff --git a/host/src/server/api/v3/proof/mod.rs b/host/src/server/api/v3/proof/mod.rs index e8668e779..37e5e349d 100644 --- a/host/src/server/api/v3/proof/mod.rs +++ b/host/src/server/api/v3/proof/mod.rs @@ -1,8 +1,11 @@ use crate::{ interfaces::HostResult, metrics::{inc_current_req, inc_guest_req_count, inc_host_req_count}, - server::api::util::ensure_not_paused, - server::api::{v2, v3::Status}, + server::api::{ + util::{ensure_aggregation_request_image_id, ensure_not_paused}, + v2, + v3::Status, + }, Message, ProverState, }; use axum::{debug_handler, extract::State, routing::post, Json, Router}; @@ -42,6 +45,13 @@ async fn proof_handler( ensure_not_paused(&prover_state)?; + // TODO: remove this assert after we support custom image_id for RISC0/SP1 proof type + assert!( + aggregation_request.image_id.is_none(), + "currently we don't support custom image_id for RISC0/SP1 proof type" + ); + ensure_aggregation_request_image_id(&mut aggregation_request)?; + // Override the existing proof request config from the config file and command line // options with the request from the client. aggregation_request.merge(&prover_state.request_config())?; @@ -68,13 +78,14 @@ async fn proof_handler( ) .await?; - let key = ProofTaskDescriptor::from(( + let key = ProofTaskDescriptor::new( chain_id, proof_request.block_number, blockhash, proof_request.proof_type, proof_request.prover.to_string(), - )); + proof_request.image_id.clone(), + ); tasks.push((key, proof_request)); } @@ -87,6 +98,7 @@ async fn proof_handler( for (key, req) in tasks.iter() { let status = manager.get_task_proving_status(key).await?; + tracing::info!("/v3/proof, request: {:?}, status: {:?}", req, status); if let Some((latest_status, ..)) = status.0.last() { match latest_status { @@ -150,6 +162,7 @@ async fn proof_handler( proofs, proof_type: aggregation_request.proof_type, prover_args: aggregation_request.prover_args, + image_id: aggregation_request.image_id, }; let status = manager @@ -243,7 +256,8 @@ mod tests { use std::path::PathBuf; use tower::ServiceExt; - #[tokio::test] + #[serial_test::serial] + #[test_log::test(tokio::test)] async fn test_proof_handler_when_paused() { let opts = { let mut opts = crate::Opts::parse(); diff --git a/host/tests/common/request.rs b/host/tests/common/request.rs index d807587a1..7e58e84e1 100644 --- a/host/tests/common/request.rs +++ b/host/tests/common/request.rs @@ -1,8 +1,10 @@ -use raiko_core::interfaces::{AggregationOnlyRequest, ProofRequestOpt, ProverSpecificOpts}; +use raiko_core::interfaces::{ + get_aggregation_image, AggregationOnlyRequest, ProofRequestOpt, ProverSpecificOpts, +}; use raiko_host::server::api; -use raiko_lib::consts::Network; use raiko_lib::proof_type::ProofType; use raiko_lib::prover::Proof; +use raiko_lib::{consts::Network, prover::encode_image_id}; use raiko_tasks::{TaskDescriptor, TaskReport, TaskStatus}; use serde_json::json; @@ -20,6 +22,13 @@ pub fn make_proof_request( block_number, std::time::Instant::now().elapsed().as_secs() ); + let image_id = match proof_type { + ProofType::Sp1 | ProofType::Risc0 => { + let (_, image_id) = get_aggregation_image(*proof_type).unwrap(); + Some(encode_image_id(image_id)) + } + ProofType::Native | ProofType::Sgx => None, + }; ProofRequestOpt { block_number: Some(block_number), network: Some(network.to_string()), @@ -46,6 +55,7 @@ pub fn make_proof_request( sgx: None, sp1: None, }, + image_id, } } @@ -66,6 +76,13 @@ pub async fn make_aggregate_proof_request( .join(","), std::time::Instant::now().elapsed().as_secs() ); + let image_id = match proof_type { + ProofType::Sp1 | ProofType::Risc0 => { + let (_, image_id) = get_aggregation_image(*proof_type).unwrap(); + Some(encode_image_id(image_id)) + } + ProofType::Native | ProofType::Sgx => None, + }; AggregationOnlyRequest { aggregation_ids: block_numbers, proofs, @@ -83,6 +100,7 @@ pub async fn make_aggregate_proof_request( sgx: None, sp1: None, }, + image_id, } } @@ -246,7 +264,8 @@ pub async fn get_status_of_aggregation_proof_request( client: &Client, request: &AggregationOnlyRequest, ) -> TaskStatus { - let expected_task_descriptor: TaskDescriptor = TaskDescriptor::Aggregation(request.into()); + let expected_task_descriptor: TaskDescriptor = + TaskDescriptor::Aggregation(request.try_into().unwrap()); let report = v2_assert_report(client).await; for (task_descriptor, task_status) in &report { if task_descriptor == &expected_task_descriptor { diff --git a/host/tests/common/server.rs b/host/tests/common/server.rs index 7bc0f80ab..e55f13130 100644 --- a/host/tests/common/server.rs +++ b/host/tests/common/server.rs @@ -11,11 +11,13 @@ use tokio_util::sync::CancellationToken; /// ``` /// let server = TestServerBuilder::default() /// .port(8080) +/// .redis_url("redis://127.0.0.1:6379/0".to_string()) /// .build(); /// ``` #[derive(Default, Debug)] pub struct TestServerBuilder { port: Option, + redis_url: Option, } impl TestServerBuilder { @@ -24,11 +26,19 @@ impl TestServerBuilder { self } + pub fn redis_url(mut self, redis_url: String) -> Self { + self.redis_url = Some(redis_url); + self + } + pub fn build(self) -> TestServerHandle { let port = self .port .unwrap_or(rand::thread_rng().gen_range(1024..65535)); let address = format!("127.0.0.1:{port}"); + let redis_url = self + .redis_url + .unwrap_or("redis://localhost:6379/0".to_string()); // TODO // opts.config_path @@ -38,7 +48,9 @@ impl TestServerBuilder { address, log_level, + redis_url, concurrency_limit: 16, + redis_ttl: 3600, ..Default::default() }; let state = ProverState::init_with_opts(opts).expect("Failed to initialize prover state"); diff --git a/host/tests/common/setup.rs b/host/tests/common/setup.rs index 0ec82564c..703ee1747 100644 --- a/host/tests/common/setup.rs +++ b/host/tests/common/setup.rs @@ -2,9 +2,16 @@ use crate::common::Client; use crate::common::{TestServerBuilder, TestServerHandle}; use rand::Rng; +pub const REDIS_URL_PREFIX: &str = "redis://localhost:6379/"; + +// TODO: make sure redis is not used by other tests pub async fn setup() -> (TestServerHandle, Client) { let port = rand::thread_rng().gen_range(1024..65535); - let server = TestServerBuilder::default().port(port).build(); + let redis_database = port % 15; // port is randomly generated, so it can be used as redis database + let server = TestServerBuilder::default() + .port(port) + .redis_url(format!("{REDIS_URL_PREFIX}{redis_database}")) + .build(); let client = server.get_client(); // Wait for the server to be ready diff --git a/host/tests/test/aggregate_test.rs b/host/tests/test/aggregate_test.rs index 1dbbd1920..ac2a483ab 100644 --- a/host/tests/test/aggregate_test.rs +++ b/host/tests/test/aggregate_test.rs @@ -7,14 +7,16 @@ use raiko_lib::consts::Network; use raiko_lib::proof_type::ProofType; use raiko_tasks::TaskStatus; -#[tokio::test] +#[serial_test::serial] +#[test_log::test(tokio::test)] async fn test_v2_mainnet_aggregate_native() { test_v2_mainnet_aggregate(Network::TaikoMainnet, ProofType::Native).await; } #[ignore] +#[serial_test::serial] #[cfg(feature = "risc0")] -#[tokio::test] +#[test_log::test(tokio::test)] async fn test_v2_mainnet_aggregate_risc0() { test_v2_mainnet_aggregate(Network::TaikoMainnet, ProofType::Risc0).await; } @@ -23,7 +25,7 @@ async fn test_v2_mainnet_aggregate(network: Network, proof_type: ProofType) { setup_mock_zkvm_elf(); let api_version = "v2"; - let aggregate_block_count = 2; + let aggregate_block_count = 1; let block_numbers = randomly_select_blocks(network, aggregate_block_count) .await diff --git a/host/tests/test/cancel_test.rs b/host/tests/test/cancel_test.rs index b5adf04a1..adcbcb204 100644 --- a/host/tests/test/cancel_test.rs +++ b/host/tests/test/cancel_test.rs @@ -4,7 +4,9 @@ use raiko_lib::consts::Network; use raiko_lib::proof_type::ProofType; use raiko_tasks::TaskStatus; -#[tokio::test] +#[ignore] +#[serial_test::serial] +#[test_log::test(tokio::test)] pub async fn test_v2_mainnet_native_cancel() { let api_version = "v2"; let network = Network::TaikoMainnet; @@ -47,7 +49,9 @@ pub async fn test_v2_mainnet_native_cancel() { assert!(matches!(status, api::v2::CancelStatus::Ok),); } -#[tokio::test] +#[ignore] +#[serial_test::serial] +#[test_log::test(tokio::test)] pub async fn test_v2_mainnet_native_cancel_non_registered() { let api_version = "v2"; let network = Network::TaikoMainnet; @@ -70,7 +74,9 @@ pub async fn test_v2_mainnet_native_cancel_non_registered() { ); } -#[tokio::test] +#[ignore] +#[serial_test::serial] +#[test_log::test(tokio::test)] pub async fn test_v2_mainnet_native_cancel_then_register() { let api_version = "v2"; let network = Network::TaikoMainnet; diff --git a/host/tests/test/manual_test.rs b/host/tests/test/manual_test.rs index dc73e0c4a..a345af745 100644 --- a/host/tests/test/manual_test.rs +++ b/host/tests/test/manual_test.rs @@ -1,8 +1,10 @@ use crate::common::{complete_proof_request, v2_assert_report, Client}; -use raiko_core::interfaces::{ProofRequestOpt, ProverSpecificOpts}; +use raiko_core::interfaces::{get_aggregation_image, ProofRequestOpt, ProverSpecificOpts}; use raiko_host::server::api; +use raiko_lib::{proof_type::ProofType, prover::encode_image_id}; use raiko_tasks::TaskStatus; use serde_json::json; +use std::str::FromStr; /// This test is used to manually test the proof process. Operator can use this to test case to /// simplly test online service. @@ -19,7 +21,7 @@ use serde_json::json; /// RAIKO_TEST_MANUAL_PROVE_RAIKO_RPC_URL=https://rpc.raiko.xyz \ /// cargo test --test test_manual_prove -- --ignored /// ``` -#[tokio::test] +#[test_log::test(tokio::test)] #[ignore] pub async fn test_manual_prove() { let enabled = std::env::var("RAIKO_TEST_MANUAL_PROVE_ENABLED").unwrap_or_default() == "false"; @@ -32,16 +34,26 @@ pub async fn test_manual_prove() { let api_version = std::env::var("RAIKO_TEST_MANUAL_PROVE_API_VERSION").unwrap_or_default(); let network = std::env::var("RAIKO_TEST_MANUAL_PROVE_NETWORK").unwrap_or_default(); let proof_type = std::env::var("RAIKO_TEST_MANUAL_PROVE_PROOF_TYPE").unwrap_or_default(); + let proof_type = ProofType::from_str(&proof_type).unwrap(); let block_number = std::env::var("RAIKO_TEST_MANUAL_PROVE_BLOCK_NUMBER") .map(|s| s.parse::().unwrap()) .unwrap(); let raiko_rpc_url = std::env::var("RAIKO_TEST_MANUAL_PROVE_RAIKO_RPC_URL").unwrap_or_default(); + let image_id = match proof_type { + ProofType::Sp1 | ProofType::Risc0 => { + let (_, image_id) = get_aggregation_image(proof_type).unwrap(); + Some(encode_image_id(image_id)) + } + ProofType::Native | ProofType::Sgx => None, + }; + let client = Client::new(raiko_rpc_url.clone()); let request = ProofRequestOpt { block_number: Some(block_number), network: Some(network.clone()), - proof_type: Some(proof_type.clone()), + proof_type: Some(proof_type.to_string()), + image_id, // Untesting parameters l1_inclusion_block_number: None, diff --git a/host/tests/test/prove_test.rs b/host/tests/test/prove_test.rs index 42da83735..6722b0210 100644 --- a/host/tests/test/prove_test.rs +++ b/host/tests/test/prove_test.rs @@ -6,7 +6,9 @@ use raiko_lib::consts::Network; use raiko_lib::proof_type::ProofType; use raiko_tasks::TaskStatus; -#[tokio::test] +#[ignore] +#[serial_test::serial] +#[test_log::test(tokio::test)] pub async fn test_v2_mainnet_native_prove() { let api_version = "v2"; let network = Network::TaikoMainnet; @@ -14,7 +16,7 @@ pub async fn test_v2_mainnet_native_prove() { let block_number = randomly_select_block(network) .await .expect("randomly select block failed"); - println!( + tracing::info!( "test_prove_v2_mainnet_native network: {network}, proof_type: {proof_type}, block_number: {block_number}" ); diff --git a/lib/Cargo.toml b/lib/Cargo.toml index 33b3856b0..0493c2fd6 100644 --- a/lib/Cargo.toml +++ b/lib/Cargo.toml @@ -38,6 +38,7 @@ kzg_traits = { workspace = true } sha2 = { workspace = true } sha3 = { workspace = true } rlp = { workspace = true, features = ["std"] } +bytemuck = { workspace = true } # docs utoipa = { workspace = true } @@ -75,4 +76,4 @@ std = [ sgx = [] sp1 = [] risc0 = [] -sp1-cycle-tracker = [] \ No newline at end of file +sp1-cycle-tracker = [] diff --git a/lib/src/proof_type.rs b/lib/src/proof_type.rs index c1c45dc76..00c6fd116 100644 --- a/lib/src/proof_type.rs +++ b/lib/src/proof_type.rs @@ -48,7 +48,7 @@ impl std::str::FromStr for ProofType { "sp1" => Ok(ProofType::Sp1), "sgx" => Ok(ProofType::Sgx), "risc0" => Ok(ProofType::Risc0), - _ => Err(format!("Unknown proof type {}", s)), + _ => Err(format!("Unknown proof type: {}", s)), } } } diff --git a/lib/src/prover.rs b/lib/src/prover.rs index 0b1bf3498..8dbcb6933 100644 --- a/lib/src/prover.rs +++ b/lib/src/prover.rs @@ -70,4 +70,55 @@ pub trait Prover { ) -> ProverResult; async fn cancel(proof_key: ProofKey, read: Box<&mut dyn IdStore>) -> ProverResult<()>; + + /// The image id and ELF of current proving guest program. + fn current_proving_image() -> (&'static [u8], &'static [u32; 8]) { + (&[], &[0; 8]) + } + + /// The image id and ELF of current aggregation guest program. + fn current_aggregation_image() -> (&'static [u8], &'static [u32; 8]) { + (&[], &[0; 8]) + } +} + +/// A helper function to encode image id to hex string. +pub fn encode_image_id(image_id: &[u32; 8]) -> String { + let bytes = bytemuck::cast_slice(image_id); + hex::encode(bytes) +} + +/// A helper function to decode image id from hex string. +pub fn decode_image_id(hex_image_id: &str) -> Result<[u32; 8], hex::FromHexError> { + let bytes: Vec = if let Some(stripped) = hex_image_id.strip_prefix("0x") { + hex::decode(stripped)? + } else { + hex::decode(hex_image_id)? + }; + let array: &[u32] = bytemuck::cast_slice::(&bytes); + let result: [u32; 8] = array.try_into().expect("invalid hex image id"); + Ok(result) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_encode_image_id() { + let image_id: [u32; 8] = [ + 1848002361, 3447634449, 2932177819, 2827220601, 4284138344, 2572487667, 1602600202, + 3769687346, + ]; + let encoded = encode_image_id(&image_id); + assert_eq!( + encoded, + "3947266e11ba7ecd9b7bc5ae79f683a868c35afff30b55990abd855f32ddb0e0" + ); + let decoded = decode_image_id(&encoded).unwrap(); + assert_eq!(decoded, image_id); + + let decoded = decode_image_id(&format!("0x{encoded}")).unwrap(); + assert_eq!(decoded, image_id); + } } diff --git a/provers/risc0/driver/src/lib.rs b/provers/risc0/driver/src/lib.rs index 9dad15517..e89f21c04 100644 --- a/provers/risc0/driver/src/lib.rs +++ b/provers/risc0/driver/src/lib.rs @@ -2,10 +2,6 @@ #[cfg(feature = "bonsai-auto-scaling")] use crate::bonsai::auto_scaling::shutdown_bonsai; -use crate::{ - methods::risc0_aggregation::RISC0_AGGREGATION_ELF, - methods::risc0_guest::{RISC0_GUEST_ELF, RISC0_GUEST_ID}, -}; use alloy_primitives::{hex::ToHexExt, B256}; use bonsai::{cancel_proof, maybe_prove}; use log::{info, warn}; @@ -15,7 +11,10 @@ use raiko_lib::{ ZkAggregationGuestInput, }, proof_type::ProofType, - prover::{IdStore, IdWrite, Proof, ProofKey, Prover, ProverConfig, ProverError, ProverResult}, + prover::{ + encode_image_id, IdStore, IdWrite, Proof, ProofKey, Prover, ProverConfig, ProverError, + ProverResult, + }, }; use risc0_zkvm::{ compute_image_id, default_prover, serde::to_vec, sha::Digestible, ExecutorEnv, ProverOpts, @@ -24,7 +23,6 @@ use risc0_zkvm::{ use serde::{Deserialize, Serialize}; use serde_with::serde_as; use std::fmt::Debug; -use tracing::debug; pub mod bonsai; pub mod methods; @@ -77,13 +75,20 @@ impl Prover for Risc0Prover { ProofType::Risc0 as u8, ); - debug!("elf code length: {}", RISC0_GUEST_ELF.len()); + let (elf, image_id) = Risc0Prover::current_proving_image(); + + info!( + "Using risc0 image id: {}, elf.length: {}", + encode_image_id(image_id), + elf.len() + ); + let encoded_input = to_vec(&input).expect("Could not serialize proving input!"); let (uuid, receipt) = maybe_prove::( &config, encoded_input, - RISC0_GUEST_ELF, + elf, &output.hash, (Vec::::new(), Vec::new()), proof_key, @@ -132,6 +137,20 @@ impl Prover for Risc0Prover { "Aggregation must be in bonsai snark mode" ); + let (proving_elf, proving_image_id) = Risc0Prover::current_proving_image(); + let (aggregation_elf, aggregation_image_id) = Risc0Prover::current_aggregation_image(); + + info!( + "Using risc0 proving image id: {}, elf.length: {}", + encode_image_id(proving_image_id), + proving_elf.len() + ); + info!( + "Using risc0 aggregation image id: {}, elf.length: {}", + encode_image_id(aggregation_image_id), + aggregation_elf.len() + ); + // Extract the block proof receipts let assumptions: Vec = input .proofs @@ -148,7 +167,8 @@ impl Prover for Risc0Prover { .map(|proof| proof.input.unwrap()) .collect::>(); let input = ZkAggregationGuestInput { - image_id: RISC0_GUEST_ID, + // TODO(Kero): use input.image_id + image_id: *proving_image_id, block_inputs, }; info!("Start aggregate proofs"); @@ -163,7 +183,7 @@ impl Prover for Risc0Prover { let opts = ProverOpts::groth16(); let receipt = default_prover() - .prove_with_opts(env, RISC0_AGGREGATION_ELF, &opts) + .prove_with_opts(env, aggregation_elf, &opts) .unwrap() .receipt; @@ -171,8 +191,8 @@ impl Prover for Risc0Prover { "Generate aggregation receipt journal: {:?}", alloy_primitives::hex::encode_prefixed(receipt.journal.bytes.clone()) ); - let block_proof_image_id = compute_image_id(RISC0_GUEST_ELF).unwrap(); - let aggregation_image_id = compute_image_id(RISC0_AGGREGATION_ELF).unwrap(); + let block_proof_image_id = compute_image_id(proving_elf).unwrap(); + let aggregation_image_id = compute_image_id(aggregation_elf).unwrap(); let proof_data = snarks::verify_aggregation_groth16_proof( block_proof_image_id, aggregation_image_id, @@ -218,12 +238,21 @@ impl Prover for Risc0Prover { .map_err(|e| ProverError::GuestError(e.to_string()))?; id_store.remove_id(key).await } + + fn current_proving_image() -> (&'static [u8], &'static [u32; 8]) { + use crate::methods::risc0_guest::{RISC0_GUEST_ELF, RISC0_GUEST_ID}; + (RISC0_GUEST_ELF, &RISC0_GUEST_ID) + } + + fn current_aggregation_image() -> (&'static [u8], &'static [u32; 8]) { + use crate::methods::risc0_aggregation::{RISC0_AGGREGATION_ELF, RISC0_AGGREGATION_ID}; + (RISC0_AGGREGATION_ELF, &RISC0_AGGREGATION_ID) + } } #[cfg(test)] mod test { use super::*; - use methods::risc0_guest::RISC0_GUEST_ID; use methods::test_risc0_guest::{TEST_RISC0_GUEST_ELF, TEST_RISC0_GUEST_ID}; use risc0_zkvm::{default_prover, ExecutorEnv}; @@ -239,9 +268,15 @@ mod test { #[ignore = "only to print image id for docker image build"] #[test] fn test_show_risc0_image_id() { - let image_id = RISC0_GUEST_ID - .map(|limp| hex::encode(limp.to_le_bytes())) - .concat(); - println!("RISC0 IMAGE_ID: {}", image_id); + let (_, proving_image_id) = Risc0Prover::current_proving_image(); + let (_, aggregation_image_id) = Risc0Prover::current_aggregation_image(); + println!( + "RISC0 PROVING IMAGE_ID: {}", + encode_image_id(proving_image_id) + ); + println!( + "RISC0 AGGREGATION IMAGE_ID: {}", + encode_image_id(aggregation_image_id) + ); } } diff --git a/provers/sp1/contracts/README.md b/provers/sp1/contracts/README.md index 72b7e3205..a87523154 100644 --- a/provers/sp1/contracts/README.md +++ b/provers/sp1/contracts/README.md @@ -25,7 +25,7 @@ The contract and proof is generated by [./provers/sp1/driver/src/gen_verifier.rs RUST_LOG=debug ``` - generate proof given `GuestInput` in [input.json](../contracts/src/fixtures/input.json) either remotely or locally -- serealize the proof into [fixture.json](../contracts/src/fixtures/fixture.json) to be tested in `RaikoVerifier.sol` +- serialize the proof into [fixture.json](../contracts/src/fixtures/fixture.json) to be tested in `RaikoVerifier.sol` ### To verify a different block You can either start Raiko and run [prove-block.sh](../../../script/prove-block.sh) to get the block you want from certain network and then run the `run-verifier` to prove. Make sure the prover is not in `mock` mode. Finally, you can verify with smart contract test. @@ -56,4 +56,4 @@ $ cargo run --bin sp1-verifier -- input-taiko-mainnet-192317.json You can also run [./script/sp1-e2e.sh](/script/sp1-e2e.sh) which does the same thing: ``` $ ./script/sp1-e2e.sh taiko_mainnet 192317 -``` \ No newline at end of file +``` diff --git a/provers/sp1/driver/src/lib.rs b/provers/sp1/driver/src/lib.rs index ddc691e78..d07cbab9c 100644 --- a/provers/sp1/driver/src/lib.rs +++ b/provers/sp1/driver/src/lib.rs @@ -9,7 +9,10 @@ use raiko_lib::{ ZkAggregationGuestInput, }, proof_type::ProofType, - prover::{IdStore, IdWrite, Proof, ProofKey, Prover, ProverConfig, ProverError, ProverResult}, + prover::{ + encode_image_id, IdStore, IdWrite, Proof, ProofKey, Prover, ProverConfig, ProverError, + ProverResult, + }, Measurement, }; use reth_primitives::B256; @@ -31,9 +34,6 @@ use tracing::{debug, error, info}; mod proof_verify; use proof_verify::remote_contract_verify::verify_sol_by_contract_call; -pub const ELF: &[u8] = include_bytes!("../../guest/elf/sp1-guest"); -pub const AGGREGATION_ELF: &[u8] = include_bytes!("../../guest/elf/sp1-aggregation"); - #[serde_as] #[derive(Clone, Debug, Serialize, Deserialize)] pub struct Sp1Param { @@ -129,6 +129,7 @@ impl Prover for Sp1Prover { let mut stdin = SP1Stdin::new(); stdin.write(&input); + let (elf, image_id) = Self::current_proving_image(); let Sp1ProverClient { client, pk, vk } = BLOCK_PROOF_CLIENT .entry(mode.clone()) .or_insert_with(|| { @@ -139,21 +140,17 @@ impl Prover for Sp1Prover { }; let client = Arc::new(base_client); - let (pk, vk) = client.setup(ELF); + let (pk, vk) = client.setup(elf); info!( - "new client and setup() for block {:?}.", - output.header.number + "Sp1 Prover: block {:?} with vk {:?} image_id {}", + output.header.number, + vk.bytes32(), + encode_image_id(image_id) ); Sp1ProverClient { client, pk, vk } }) .clone(); - info!( - "Sp1 Prover: block {:?} with vk {:?}", - output.header.number, - vk.bytes32() - ); - let prove_action = action::Prove::new(client.prover.as_ref(), &pk, stdin.clone()); let prove_result = if !matches!(mode, ProverMode::Network) { debug!("Proving locally with recursion mode: {:?}", param.recursion); @@ -167,7 +164,7 @@ impl Prover for Sp1Prover { let network_prover = NetworkProver::new(); let proof_id = network_prover - .request_proof(ELF, stdin, param.recursion.clone().into()) + .request_proof(elf, stdin, param.recursion.clone().into()) .await .map_err(|e| { ProverError::GuestError(format!("Sp1: requesting proof failed: {e}")) @@ -323,7 +320,9 @@ impl Prover for Sp1Prover { .unwrap_or_else(ProverClient::new); let client = Arc::new(base_client); - let (pk, vk) = client.setup(AGGREGATION_ELF); + let (aggregation_elf, _) = Self::current_aggregation_image(); + let (pk, vk) = client.setup(aggregation_elf); + info!( "new client and setup() for aggregation based on {:?} proofs with vk {:?}", input.proofs.len(), @@ -380,6 +379,32 @@ impl Prover for Sp1Prover { .into(), ) } + + fn current_proving_image() -> (&'static [u8], &'static [u32; 8]) { + const PROVING_ELF: &[u8] = include_bytes!("../../guest/elf/sp1-guest"); + + static PROVING_IMAGE_ID: once_cell::sync::Lazy<[u32; 8]> = + once_cell::sync::Lazy::new(|| { + let local_client = ProverClient::local(); + let (_, vk) = local_client.setup(PROVING_ELF); + vk.hash_u32() + }); + + (PROVING_ELF, &PROVING_IMAGE_ID) + } + + fn current_aggregation_image() -> (&'static [u8], &'static [u32; 8]) { + static AGGREGATION_ELF: &[u8] = include_bytes!("../../guest/elf/sp1-aggregation"); + + static AGGREGATION_IMAGE_ID: once_cell::sync::Lazy<[u32; 8]> = + once_cell::sync::Lazy::new(|| { + let local_client = ProverClient::local(); + let (_, vk) = local_client.setup(AGGREGATION_ELF); + vk.hash_u32() + }); + + (AGGREGATION_ELF, &AGGREGATION_IMAGE_ID) + } } fn get_env_mock() -> ProverMode { diff --git a/script/build.sh b/script/build.sh index b469de31f..fd799d7d5 100755 --- a/script/build.sh +++ b/script/build.sh @@ -46,6 +46,11 @@ if [ "$CPU_OPT" = "1" ]; then echo "Enable cpu optimization with host RUSTFLAGS" fi +# install cargo-nextest if not installed +if ! cargo nextest --version >/dev/null 2>&1; then + cargo install cargo-nextest@0.9.85 --locked +fi + # NATIVE if [ -z "$1" ] || [ "$1" == "native" ]; then if [ -n "${CLIPPY}" ]; then diff --git a/taskdb/src/lib.rs b/taskdb/src/lib.rs index 513843a13..49ab410ee 100644 --- a/taskdb/src/lib.rs +++ b/taskdb/src/lib.rs @@ -1,4 +1,7 @@ -use std::io::{Error as IOError, ErrorKind as IOErrorKind}; +use std::{ + io::{Error as IOError, ErrorKind as IOErrorKind}, + str::FromStr, +}; use chrono::{DateTime, Utc}; use raiko_core::interfaces::AggregationOnlyRequest; @@ -8,7 +11,6 @@ use raiko_lib::{ prover::{IdStore, IdWrite, ProofKey, ProverResult}, }; use serde::{Deserialize, Serialize}; -use tracing::debug; use utoipa::ToSchema; #[cfg(feature = "in-memory")] @@ -147,24 +149,31 @@ pub struct ProofTaskDescriptor { pub blockhash: B256, pub proof_system: ProofType, pub prover: String, + pub image_id: Option, } -impl From<(ChainId, u64, B256, ProofType, String)> for ProofTaskDescriptor { - fn from( - (chain_id, block_id, blockhash, proof_system, prover): ( - ChainId, - u64, - B256, - ProofType, - String, - ), +impl ProofTaskDescriptor { + /// Create a new ProofTaskDescriptor. + /// For RISC0 and SP1 provers, image_id should be provided. + pub fn new( + chain_id: ChainId, + block_id: u64, + blockhash: B256, + proof_system: ProofType, + prover: String, + image_id: Option, ) -> Self { - ProofTaskDescriptor { + debug_assert!( + matches!(proof_system, ProofType::Native | ProofType::Sgx) || image_id.is_some(), + "RISC0/SP1 provers require image_id to be provided" + ); + Self { chain_id, block_id, blockhash, proof_system, prover, + image_id, } } } @@ -183,20 +192,38 @@ impl From for (ChainId, B256) { #[derive(Default, Clone, Serialize, Deserialize, Debug, PartialEq, Eq, Hash)] #[serde(default)] -/// A request for proof aggregation of multiple proofs. pub struct AggregationTaskDescriptor { - /// The block numbers and l1 inclusion block numbers for the blocks to aggregate proofs for. pub aggregation_ids: Vec, - /// The proof type. pub proof_type: Option, + pub image_id: Option, } -impl From<&AggregationOnlyRequest> for AggregationTaskDescriptor { - fn from(request: &AggregationOnlyRequest) -> Self { - Self { - aggregation_ids: request.aggregation_ids.clone(), - proof_type: request.proof_type.clone().map(|p| p.to_string()), +impl TryFrom<&AggregationOnlyRequest> for AggregationTaskDescriptor { + type Error = String; + + fn try_from(request: &AggregationOnlyRequest) -> Result { + // Check if we need an image ID for this proof type + if let Some(pt) = request + .proof_type + .as_ref() + .and_then(|pt| ProofType::from_str(pt).ok()) + { + match pt { + ProofType::Risc0 | ProofType::Sp1 if request.image_id.is_none() => { + return Err("RISC0/SP1 provers require image_id to be provided".to_string()); + } + ProofType::Native | ProofType::Sgx if request.image_id.is_some() => { + return Err("Native/SGX provers must not have image_id".to_string()); + } + _ => {} + } } + + Ok(Self { + aggregation_ids: request.aggregation_ids.clone(), + proof_type: request.proof_type.clone(), + image_id: request.image_id.clone(), + }) } } @@ -224,7 +251,7 @@ pub struct TaskManagerOpts { } #[async_trait::async_trait] -pub trait TaskManager: IdStore + IdWrite + Send + Sync { +pub trait TaskManager: IdStore + IdWrite + Send + Sync + Clone { /// Create a new task manager. fn new(opts: &TaskManagerOpts) -> Self; @@ -304,6 +331,7 @@ pub fn ensure(expression: bool, message: &str) -> TaskManagerResult<()> { Ok(()) } +#[derive(Debug, Clone)] pub struct TaskManagerWrapper { manager: T, } @@ -423,14 +451,9 @@ impl TaskManager for TaskManagerWrapper { #[cfg(feature = "in-memory")] pub type TaskManagerWrapperImpl = TaskManagerWrapper; -#[cfg(feature = "redis-db")] +#[cfg(all(feature = "redis-db", not(feature = "in-memory")))] pub type TaskManagerWrapperImpl = TaskManagerWrapper; -pub fn get_task_manager(opts: &TaskManagerOpts) -> TaskManagerWrapperImpl { - debug!("get task manager with options: {:?}", opts); - TaskManagerWrapperImpl::new(opts) -} - #[cfg(test)] mod test { use super::*; @@ -443,18 +466,19 @@ mod test { redis_url: "redis://localhost:6379".to_string(), redis_ttl: 3600, }; - let mut task_manager = get_task_manager(&opts); + let mut task_manager = TaskManagerWrapperImpl::new(&opts); let block_id = rand::thread_rng().gen_range(0..1000000); assert_eq!( task_manager - .enqueue_task(&ProofTaskDescriptor { - chain_id: 1, + .enqueue_task(&ProofTaskDescriptor::new( + ChainId::from(1u64), block_id, - blockhash: B256::default(), - proof_system: ProofType::Native, - prover: "test".to_string(), - }) + B256::default(), + ProofType::Native, + "test".to_string(), + None, + )) .await .unwrap() .0 @@ -470,15 +494,16 @@ mod test { redis_url: "redis://localhost:6379".to_string(), redis_ttl: 3600, }; - let mut task_manager = get_task_manager(&opts); + let mut task_manager = TaskManagerWrapperImpl::new(&opts); let block_id = rand::thread_rng().gen_range(0..1000000); - let key = ProofTaskDescriptor { - chain_id: 1, + let key = ProofTaskDescriptor::new( + ChainId::from(1u64), block_id, - blockhash: B256::default(), - proof_system: ProofType::Native, - prover: "test".to_string(), - }; + B256::default(), + ProofType::Native, + "test".to_string(), + None, + ); assert_eq!(task_manager.enqueue_task(&key).await.unwrap().0.len(), 1); // enqueue again diff --git a/taskdb/src/mem_db.rs b/taskdb/src/mem_db.rs index 8ebab1b7e..31bd5c93e 100644 --- a/taskdb/src/mem_db.rs +++ b/taskdb/src/mem_db.rs @@ -7,10 +7,7 @@ // Imports // ---------------------------------------------------------------- -use std::{ - collections::HashMap, - sync::{Arc, Once}, -}; +use std::{collections::HashMap, sync::Arc}; use chrono::Utc; use raiko_core::interfaces::AggregationOnlyRequest; @@ -24,7 +21,7 @@ use crate::{ TaskReport, TaskStatus, }; -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct InMemoryTaskManager { db: Arc>, } @@ -138,7 +135,11 @@ impl InMemoryTaskDb { .filter_map(|(desc, statuses)| { statuses.0.last().map(|s| { ( - TaskDescriptor::Aggregation(AggregationTaskDescriptor::from(desc)), + TaskDescriptor::Aggregation( + AggregationTaskDescriptor::try_from(desc).expect( + "invalid AggregationOnlyRequest detected in Memory DB, please report this issue", + ), + ), s.0.clone(), ) }) @@ -300,8 +301,9 @@ impl IdStore for InMemoryTaskManager { #[async_trait::async_trait] impl TaskManager for InMemoryTaskManager { + #[cfg(not(test))] fn new(_opts: &TaskManagerOpts) -> Self { - static INIT: Once = Once::new(); + static INIT: std::sync::Once = std::sync::Once::new(); static mut SHARED_TASK_MANAGER: Option>> = None; INIT.call_once(|| { @@ -312,11 +314,20 @@ impl TaskManager for InMemoryTaskManager { } }); + tracing::info!("InMemoryTaskManager created not test"); InMemoryTaskManager { db: unsafe { SHARED_TASK_MANAGER.clone().unwrap() }, } } + #[cfg(test)] + fn new(_opts: &TaskManagerOpts) -> Self { + tracing::info!("InMemoryTaskManager created test"); + InMemoryTaskManager { + db: Arc::new(Mutex::new(InMemoryTaskDb::new())), + } + } + async fn enqueue_task( &mut self, params: &ProofTaskDescriptor, @@ -438,13 +449,14 @@ mod tests { #[test] fn test_db_enqueue() { let mut db = InMemoryTaskDb::new(); - let params = ProofTaskDescriptor { - chain_id: 1, - block_id: 1, - blockhash: B256::default(), - proof_system: ProofType::Native, - prover: "0x1234".to_owned(), - }; + let params = ProofTaskDescriptor::new( + 1, + 1, + B256::default(), + ProofType::Native, + "0x1234".to_owned(), + None, + ); db.enqueue_task(¶ms).expect("enqueue task"); let status = db.get_task_proving_status(¶ms); assert!(status.is_ok()); diff --git a/taskdb/src/redis_db.rs b/taskdb/src/redis_db.rs index cddc3be42..cd10181ae 100644 --- a/taskdb/src/redis_db.rs +++ b/taskdb/src/redis_db.rs @@ -16,7 +16,7 @@ use redis::{ Client, Commands, ErrorKind, FromRedisValue, RedisError, RedisResult, RedisWrite, ToRedisArgs, Value, }; -use std::sync::{Arc, Once}; +use std::sync::Arc; use std::time::Duration; use thiserror::Error; use tokio::sync::Mutex; @@ -33,10 +33,17 @@ pub struct RedisTaskDb { config: RedisConfig, } +#[derive(Clone)] pub struct RedisTaskManager { arc_task_db: Arc>, } +impl std::fmt::Debug for RedisTaskManager { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "RedisTaskManager") + } +} + type RedisDbResult = Result; #[derive(Error, Debug)] @@ -224,7 +231,7 @@ impl RedisTaskDb { } else { error!("Failed to deserialize TaskProvingStatusRecords"); Err(RedisDbError::TaskManager( - format!("Failed to deserialize TaskProvingStatusRecords").to_owned(), + "Failed to deserialize TaskProvingStatusRecords".to_string(), )) } } @@ -302,7 +309,7 @@ impl RedisTaskDb { key: &AggregationTaskDescriptor, new_status: TaskProvingStatus, ) -> RedisDbResult<()> { - let old_value = self.query_aggregation_task(key.into()).unwrap_or_default(); + let old_value = self.query_aggregation_task(key).unwrap_or_default(); let mut records = match old_value { Some(v) => v, None => { @@ -450,11 +457,15 @@ impl RedisTaskDb { ) { (Ok(desc), _) => { let status = self.query_proof_task_latest_status(&desc)?; - status.map(|s| kvs.push((TaskDescriptor::SingleProof(desc), s.0))); + if let Some(s) = status { + kvs.push((TaskDescriptor::SingleProof(desc), s.0)) + } } (_, Ok(desc)) => { let status = self.query_aggregation_task_latest_status(&desc)?; - status.map(|s| kvs.push((TaskDescriptor::Aggregation(desc), s.0))); + if let Some(s) = status { + kvs.push((TaskDescriptor::Aggregation(desc), s.0)) + } } _ => (), } @@ -465,7 +476,9 @@ impl RedisTaskDb { fn enqueue_aggregation_task(&mut self, request: &AggregationOnlyRequest) -> RedisDbResult<()> { let task_status = (TaskStatus::Registered, None, Utc::now()); - let agg_task_descriptor = request.into(); + let agg_task_descriptor = request + .try_into() + .map_err(|e: String| RedisDbError::TaskManager(e))?; match self.query_aggregation_task(&agg_task_descriptor)? { Some(task_proving_records) => { info!( @@ -488,7 +501,9 @@ impl RedisTaskDb { &mut self, request: &AggregationOnlyRequest, ) -> RedisDbResult { - let agg_task_descriptor = request.into(); + let agg_task_descriptor = request + .try_into() + .map_err(|e: String| RedisDbError::TaskManager(e))?; match self.query_aggregation_task(&agg_task_descriptor)? { Some(records) => Ok(records), None => Err(RedisDbError::KeyNotFound( @@ -503,7 +518,9 @@ impl RedisTaskDb { status: TaskStatus, proof: Option<&[u8]>, ) -> RedisDbResult<()> { - let agg_task_descriptor = request.into(); + let agg_task_descriptor = request + .try_into() + .map_err(|e: String| RedisDbError::TaskManager(e))?; match self.query_aggregation_task(&agg_task_descriptor)? { Some(records) => { if let Some(latest) = records.0.last() { @@ -528,7 +545,9 @@ impl RedisTaskDb { &mut self, request: &AggregationOnlyRequest, ) -> RedisDbResult> { - let agg_task_descriptor = request.into(); + let agg_task_descriptor = request + .try_into() + .map_err(|e: String| RedisDbError::TaskManager(e))?; let proving_status_records = self .query_aggregation_task(&agg_task_descriptor)? .unwrap_or_default(); @@ -563,11 +582,8 @@ impl RedisTaskDb { fn prune_aggregation(&mut self) -> RedisDbResult<()> { let keys: Vec = self.get_conn()?.keys("*").map_err(RedisDbError::RedisDb)?; for key in keys.iter() { - match AggregationTaskDescriptor::from_redis_value(key) { - Ok(desc) => { - self.delete_redis(&desc)?; - } - _ => (), + if let Ok(desc) = AggregationTaskDescriptor::from_redis_value(key) { + self.delete_redis(&desc)?; } } Ok(()) @@ -577,21 +593,18 @@ impl RedisTaskDb { let mut kvs: Vec = Vec::new(); let keys: Vec = self.get_conn()?.keys("*").map_err(RedisDbError::RedisDb)?; for key in keys.iter() { - match AggregationTaskDescriptor::from_redis_value(key) { - Ok(desc) => { - let status = self.query_aggregation_task_latest_status(&desc)?; - status.map(|s| { - kvs.push(( - AggregationOnlyRequest { - aggregation_ids: desc.aggregation_ids, - proof_type: desc.proof_type, - ..Default::default() - }, - s.0, - )) - }); + if let Ok(desc) = AggregationTaskDescriptor::from_redis_value(key) { + let status = self.query_aggregation_task_latest_status(&desc)?; + if let Some(s) = status { + kvs.push(( + AggregationOnlyRequest { + aggregation_ids: desc.aggregation_ids, + proof_type: desc.proof_type, + ..Default::default() + }, + s.0, + )); } - _ => (), } } Ok(kvs) @@ -623,12 +636,11 @@ impl RedisTaskDb { let mut kvs = Vec::new(); let keys: Vec = self.get_conn()?.keys("*").map_err(RedisDbError::RedisDb)?; for key in keys.iter() { - match TaskIdDescriptor::from_redis_value(key) { - Ok(desc) => { - let status = self.query_redis(&desc)?; - status.map(|s| kvs.push((desc.0, s))); + if let Ok(desc) = TaskIdDescriptor::from_redis_value(key) { + let status = self.query_redis(&desc)?; + if let Some(s) = status { + kvs.push((desc.0, s)); } - _ => (), } } Ok(kvs) @@ -637,11 +649,8 @@ impl RedisTaskDb { fn prune_stored_ids(&mut self) -> RedisDbResult<()> { let keys: Vec = self.get_conn()?.keys("*").map_err(RedisDbError::RedisDb)?; for key in keys.iter() { - match TaskIdDescriptor::from_redis_value(key) { - Ok(desc) => { - self.delete_redis(&desc)?; - } - _ => (), + if let Ok(desc) = TaskIdDescriptor::from_redis_value(key) { + self.delete_redis(&desc)?; } } Ok(()) @@ -674,25 +683,37 @@ impl IdWrite for RedisTaskManager { #[async_trait::async_trait] impl TaskManager for RedisTaskManager { + #[cfg(not(test))] fn new(opts: &TaskManagerOpts) -> Self { - static INIT: Once = Once::new(); + static INIT: std::sync::Once = std::sync::Once::new(); static mut REDIS_DB: Option>> = None; INIT.call_once(|| { unsafe { - REDIS_DB = Some(Arc::new(Mutex::new({ - let db = RedisTaskDb::new(RedisConfig { + REDIS_DB = Some(Arc::new(Mutex::new( + RedisTaskDb::new(RedisConfig { url: opts.redis_url.clone(), - ttl: opts.redis_ttl.clone(), + ttl: opts.redis_ttl, }) - .unwrap(); - db - }))) + .unwrap(), + ))) }; }); Self { arc_task_db: unsafe { REDIS_DB.clone().unwrap() }, } } + #[cfg(test)] + fn new(opts: &TaskManagerOpts) -> Self { + Self { + arc_task_db: Arc::new(Mutex::new( + RedisTaskDb::new(RedisConfig { + url: opts.redis_url.clone(), + ttl: opts.redis_ttl.clone(), + }) + .unwrap(), + )), + } + } async fn enqueue_task( &mut self, @@ -840,6 +861,7 @@ mod tests { blockhash: B256::default(), proof_system: ProofType::Native, prover: "0x1234".to_owned(), + image_id: None, }; db.enqueue_task(¶ms).expect("enqueue task failed"); let status = db.get_task_proving_status(¶ms); @@ -859,6 +881,7 @@ mod tests { blockhash: B256::default(), proof_system: ProofType::Native, prover: "0x1234".to_owned(), + image_id: None, }; db.enqueue_task(¶ms).expect("enqueue task failed"); let status = db.get_task_proving_status(¶ms); diff --git a/taskdb/tests/main.rs b/taskdb/tests/main.rs index 74b2a91b2..8cc759777 100644 --- a/taskdb/tests/main.rs +++ b/taskdb/tests/main.rs @@ -16,7 +16,7 @@ mod tests { use raiko_lib::{input::BlobProofType, primitives::B256, proof_type::ProofType}; use raiko_tasks::{ - get_task_manager, ProofTaskDescriptor, TaskManager, TaskManagerOpts, TaskStatus, + ProofTaskDescriptor, TaskManager, TaskManagerOpts, TaskManagerWrapperImpl, TaskStatus, }; fn create_random_task(rng: &mut ChaCha8Rng) -> (u64, B256, ProofRequest) { @@ -46,13 +46,14 @@ mod tests { prover_args, blob_proof_type: BlobProofType::ProofOfEquivalence, l1_inclusion_block_number: 0, + image_id: Some("test_image".to_string()), }, ) } #[tokio::test] async fn test_enqueue_task() { - let mut tama = get_task_manager(&TaskManagerOpts { + let mut tama = TaskManagerWrapperImpl::new(&TaskManagerOpts { max_db_size: 1_000_000, redis_url: env::var("REDIS_URL").unwrap_or_default(), redis_ttl: 3600, @@ -60,23 +61,20 @@ mod tests { let (chain_id, blockhash, request) = create_random_task(&mut ChaCha8Rng::seed_from_u64(123)); - tama.enqueue_task( - &( - chain_id, - request.block_number, - blockhash, - request.proof_type, - request.prover.to_string(), - ) - .into(), - ) - .await - .unwrap(); + let task = ProofTaskDescriptor::new( + chain_id.into(), + request.block_number, + blockhash, + request.proof_type, + request.prover.to_string(), + request.image_id.clone(), + ); + tama.enqueue_task(&task).await.unwrap(); } #[tokio::test] async fn test_update_query_tasks_progress() { - let mut tama = get_task_manager(&TaskManagerOpts { + let mut tama = TaskManagerWrapperImpl::new(&TaskManagerOpts { max_db_size: 1_000_000, redis_url: env::var("REDIS_URL").unwrap_or_default(), redis_ttl: 3600, @@ -88,60 +86,49 @@ mod tests { for _ in 0..5 { let (chain_id, blockhash, request) = create_random_task(&mut rng); - tama.enqueue_task( - &( - chain_id, - request.block_number, - blockhash, - request.proof_type, - request.prover.to_string(), - ) - .into(), - ) - .await - .unwrap(); + let task = ProofTaskDescriptor::new( + chain_id.into(), + request.block_number, + blockhash, + request.proof_type, + request.prover.to_string(), + request.image_id.clone(), + ); + tasks.push(task.clone()); - let task_status = tama - .get_task_proving_status( - &( - chain_id, - request.block_number, - blockhash, - request.proof_type, - request.prover.to_string(), - ) - .into(), - ) - .await - .unwrap() - .0; + tama.enqueue_task(&task).await.unwrap(); + + let task_desc = ProofTaskDescriptor::new( + chain_id.into(), + request.block_number, + blockhash, + request.proof_type, + request.prover.to_string(), + request.image_id.clone(), + ); + let task_status = tama.get_task_proving_status(&task_desc).await.unwrap().0; assert_eq!(task_status.len(), 1); let status = task_status .first() .expect("Already confirmed there is exactly 1 element"); assert_eq!(status.0, TaskStatus::Registered); - tasks.push(( - chain_id, - blockhash, + let task = ProofTaskDescriptor::new( + chain_id.into(), request.block_number, + blockhash, request.proof_type, - request.prover, - )); + request.prover.to_string(), + request.image_id.clone(), + ); + tasks.push(task); } std::thread::sleep(Duration::from_millis(1)); { - let task_0_desc: &ProofTaskDescriptor = &( - tasks[0].0, - tasks[0].2, - tasks[0].1, - tasks[0].3, - tasks[0].4.to_string(), - ) - .into(); - let task_status = tama.get_task_proving_status(task_0_desc).await.unwrap().0; + let task_0_desc = tasks[0].clone(); + let task_status = tama.get_task_proving_status(&task_0_desc).await.unwrap().0; println!("{task_status:?}"); tama.update_task_progress( task_0_desc.clone(), @@ -151,7 +138,7 @@ mod tests { .await .unwrap(); - let task_status = tama.get_task_proving_status(task_0_desc).await.unwrap().0; + let task_status = tama.get_task_proving_status(&task_0_desc).await.unwrap().0; println!("{task_status:?}"); assert_eq!(task_status.len(), 2); assert_eq!(task_status[1].0, TaskStatus::Cancelled_NeverStarted); @@ -159,21 +146,14 @@ mod tests { } // ----------------------- { - let task_1_desc: &ProofTaskDescriptor = &( - tasks[1].0, - tasks[1].2, - tasks[1].1, - tasks[1].3, - tasks[1].4.to_string(), - ) - .into(); + let task_1_desc = tasks[1].clone(); tama.update_task_progress(task_1_desc.clone(), TaskStatus::WorkInProgress, None) .await .unwrap(); { - let task_status = tama.get_task_proving_status(task_1_desc).await.unwrap().0; - assert_eq!(task_status.len(), 2); + let task_status = tama.get_task_proving_status(&task_1_desc).await.unwrap().0; + assert_eq!(task_status.len(), 2, "task_status: {:?}", task_status); assert_eq!(task_status[1].0, TaskStatus::WorkInProgress); assert_eq!(task_status[0].0, TaskStatus::Registered); } @@ -189,7 +169,7 @@ mod tests { .unwrap(); { - let task_status = tama.get_task_proving_status(task_1_desc).await.unwrap().0; + let task_status = tama.get_task_proving_status(&task_1_desc).await.unwrap().0; assert_eq!(task_status.len(), 3); assert_eq!(task_status[2].0, TaskStatus::CancellationInProgress); assert_eq!(task_status[1].0, TaskStatus::WorkInProgress); @@ -203,7 +183,7 @@ mod tests { .unwrap(); { - let task_status = tama.get_task_proving_status(task_1_desc).await.unwrap().0; + let task_status = tama.get_task_proving_status(&task_1_desc).await.unwrap().0; assert_eq!(task_status.len(), 4); assert_eq!(task_status[3].0, TaskStatus::Cancelled); assert_eq!(task_status[2].0, TaskStatus::CancellationInProgress); @@ -214,20 +194,13 @@ mod tests { // ----------------------- { - let task_2_desc: &ProofTaskDescriptor = &( - tasks[2].0, - tasks[2].2, - tasks[2].1, - tasks[2].3, - tasks[2].4.to_string(), - ) - .into(); + let task_2_desc = tasks[2].clone(); tama.update_task_progress(task_2_desc.clone(), TaskStatus::WorkInProgress, None) .await .unwrap(); { - let task_status = tama.get_task_proving_status(task_2_desc).await.unwrap().0; + let task_status = tama.get_task_proving_status(&task_2_desc).await.unwrap().0; assert_eq!(task_status.len(), 2); assert_eq!(task_status[1].0, TaskStatus::WorkInProgress); assert_eq!(task_status[0].0, TaskStatus::Registered); @@ -241,32 +214,25 @@ mod tests { .unwrap(); { - let task_status = tama.get_task_proving_status(task_2_desc).await.unwrap().0; + let task_status = tama.get_task_proving_status(&task_2_desc).await.unwrap().0; assert_eq!(task_status.len(), 3); assert_eq!(task_status[2].0, TaskStatus::Success); assert_eq!(task_status[1].0, TaskStatus::WorkInProgress); assert_eq!(task_status[0].0, TaskStatus::Registered); } - assert_eq!(proof, tama.get_task_proof(task_2_desc,).await.unwrap()); + assert_eq!(proof, tama.get_task_proof(&task_2_desc).await.unwrap()); } // ----------------------- { - let task_3_desc: &ProofTaskDescriptor = &( - tasks[3].0, - tasks[3].2, - tasks[3].1, - tasks[3].3, - tasks[3].4.to_string(), - ) - .into(); + let task_3_desc = tasks[3].clone(); tama.update_task_progress(task_3_desc.clone(), TaskStatus::WorkInProgress, None) .await .unwrap(); { - let task_status = tama.get_task_proving_status(task_3_desc).await.unwrap().0; + let task_status = tama.get_task_proving_status(&task_3_desc).await.unwrap().0; assert_eq!(task_status.len(), 2); assert_eq!(task_status[1].0, TaskStatus::WorkInProgress); assert_eq!(task_status[0].0, TaskStatus::Registered); @@ -283,7 +249,7 @@ mod tests { .unwrap(); { - let task_status = tama.get_task_proving_status(task_3_desc).await.unwrap().0; + let task_status = tama.get_task_proving_status(&task_3_desc).await.unwrap().0; assert_eq!(task_status.len(), 3); assert_eq!(task_status[2].0, TaskStatus::UnspecifiedFailureReason); assert_eq!(task_status[1].0, TaskStatus::WorkInProgress); @@ -297,7 +263,7 @@ mod tests { .unwrap(); { - let task_status = tama.get_task_proving_status(task_3_desc).await.unwrap().0; + let task_status = tama.get_task_proving_status(&task_3_desc).await.unwrap().0; assert_eq!(task_status.len(), 4); assert_eq!(task_status[3].0, TaskStatus::WorkInProgress); assert_eq!(task_status[2].0, TaskStatus::UnspecifiedFailureReason); @@ -317,7 +283,7 @@ mod tests { .unwrap(); { - let task_status = tama.get_task_proving_status(task_3_desc).await.unwrap().0; + let task_status = tama.get_task_proving_status(&task_3_desc).await.unwrap().0; assert_eq!(task_status.len(), 5); assert_eq!(task_status[4].0, TaskStatus::Success); assert_eq!(task_status[3].0, TaskStatus::WorkInProgress); @@ -326,7 +292,7 @@ mod tests { assert_eq!(task_status[0].0, TaskStatus::Registered); } - assert_eq!(proof, tama.get_task_proof(task_3_desc,).await.unwrap()); + assert_eq!(proof, tama.get_task_proof(&task_3_desc).await.unwrap()); } } }