From 6d7e1deed352d3a1656cd55aef8e3aeb711c3b08 Mon Sep 17 00:00:00 2001 From: "keroroxx520@gmail.com" Date: Mon, 17 Feb 2025 16:32:20 +0800 Subject: [PATCH] chore: apply review suggestion --- host/src/server/api/v2/proof/mod.rs | 79 ++++++++++++++++++----------- 1 file changed, 49 insertions(+), 30 deletions(-) diff --git a/host/src/server/api/v2/proof/mod.rs b/host/src/server/api/v2/proof/mod.rs index c4ee97e0..6d26bad2 100644 --- a/host/src/server/api/v2/proof/mod.rs +++ b/host/src/server/api/v2/proof/mod.rs @@ -1,5 +1,5 @@ use axum::{extract::State, routing::post, Json, Router}; -use raiko_core::interfaces::RaikoError; +use raiko_core::interfaces::{ProofRequestOpt, RaikoError}; use raiko_core::{interfaces::ProofRequest, provider::get_task_data}; use raiko_lib::proof_type::ProofType; use raiko_reqpool::{SingleProofRequestEntity, SingleProofRequestKey}; @@ -45,21 +45,8 @@ async fn proof_handler(State(actor): State, Json(req): Json) -> Ho config.merge(&req)?; // For zk_any request, draw zk proof type based on the block hash. - // - // A zk_any request looks like: { "proof_type": "zk_any", "zk_any": { "aggregation": } } - let is_zk_any = config.proof_type == Some("zk_any".to_string()); - if is_zk_any { - let network = config - .network - .as_ref() - .ok_or(RaikoError::InvalidRequestConfig( - "Missing network".to_string(), - ))?; - let block_number = config.block_number.ok_or(RaikoError::InvalidRequestConfig( - "Missing block number".to_string(), - ))?; - let (_, blockhash) = get_task_data(&network, block_number, actor.chain_specs()).await?; - match actor.draw(&blockhash) { + if is_zk_any_request(&config) { + match draw_for_zk_any_request(&actor, &config).await? { Some(proof_type) => config.proof_type = Some(proof_type.to_string()), None => { return Ok(Status::Ok { @@ -70,21 +57,13 @@ async fn proof_handler(State(actor): State, Json(req): Json) -> Ho }); } } - } - if is_zk_any && config.proof_type == Some(ProofType::Sp1.to_string()) { - // Parse req, extract the aggregation field - // { "proof_type": "zk_any", "zk_any": { "aggregation": } } - let aggregation = req["zk_any"]["aggregation"].as_bool().unwrap_or(false); - let mut sp1_opts = config - .prover_args - .sp1 - .expect("config.merge() should have set sp1"); - if aggregation { - sp1_opts["recursion"] = serde_json::Value::String("compressed".to_string()); - } else { - sp1_opts["recursion"] = serde_json::Value::String("plonk".to_string()); + // Specially process zk_any requests with sp1 parameters. + if config.proof_type == Some(ProofType::Sp1.to_string()) { + // Parse req, extract the aggregation field + // { "proof_type": "zk_any", "zk_any": { "aggregation": } } + let sp1_opts = sp1_params_for_zk_any_request(&req, &config); + config.prover_args.sp1 = Some(sp1_opts); } - config.prover_args.sp1 = Some(sp1_opts); } // Construct the actual proof request from the available configs. @@ -151,3 +130,43 @@ pub fn create_router() -> Router { .nest("/list", list::create_router()) .nest("/prune", prune::create_router()) } + +// A zk_any request looks like: { "proof_type": "zk_any", "zk_any": { "aggregation": } } +fn is_zk_any_request(proof_request_opt: &ProofRequestOpt) -> bool { + proof_request_opt.proof_type == Some("zk_any".to_string()) +} + +async fn draw_for_zk_any_request( + actor: &Actor, + proof_request_opt: &ProofRequestOpt, +) -> HostResult> { + let network = proof_request_opt + .network + .as_ref() + .ok_or(RaikoError::InvalidRequestConfig( + "Missing network".to_string(), + ))?; + let block_number = proof_request_opt + .block_number + .ok_or(RaikoError::InvalidRequestConfig( + "Missing block number".to_string(), + ))?; + let (_, blockhash) = get_task_data(&network, block_number, actor.chain_specs()).await?; + Ok(actor.draw(&blockhash)) +} + +fn sp1_params_for_zk_any_request(req: &Value, proof_request_opt: &ProofRequestOpt) -> Value { + let aggregation = req["zk_any"]["aggregation"].as_bool().unwrap_or(false); + let mut sp1_opts = proof_request_opt + .prover_args + .sp1 + .as_ref() + .expect("config.merge() should have set sp1") + .to_owned(); + if aggregation { + sp1_opts["recursion"] = serde_json::Value::String("compressed".to_string()); + } else { + sp1_opts["recursion"] = serde_json::Value::String("plonk".to_string()); + } + sp1_opts +}