Skip to content

Commit

Permalink
Fix rayon and arc usage
Browse files Browse the repository at this point in the history
  • Loading branch information
SHAcollision committed Nov 17, 2024
1 parent 95a1bbc commit 057de7f
Showing 1 changed file with 37 additions and 34 deletions.
71 changes: 37 additions & 34 deletions examples/mark_recapture_dht.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@
use dashmap::DashSet;
use mainline::{Dht, Id};
use rayon::{prelude::*, ThreadPoolBuilder};
use std::sync::Arc;
use rayon::{prelude::*, ThreadPool, ThreadPoolBuilder};
use tracing::{debug, info, Level};

/// Adjust as needed. Default will take about ~2 hours
Expand All @@ -59,6 +58,7 @@ const MIN_OVERLAP: usize = 10_000;
const MAX_RANDOM_NODE_IDS: usize = 100_000;
// Number of parallel lookups. Ideally not bigger than number of threads available. Display progress every N.
const BATCH_SIZE: usize = 16;
const Z_SCORE: f64 = 1.96;

/// Represents the DHT size estimation result.
struct EstimateResult {
Expand Down Expand Up @@ -104,25 +104,24 @@ fn main() {
.build()
.expect("Failed to build Rayon thread pool");

pool.install(|| {
// Initialize the DHT client wrapped in Arc for thread safety.
let dht = Arc::new(Dht::client().expect("Failed to create DHT client"));
// Initialize the DHT client.
let dht = Dht::client().expect("Failed to create DHT client");

// Collect samples from the DHT.
let (marked_sample, recapture_sample) =
collect_samples(&dht, MIN_OVERLAP, MAX_RANDOM_NODE_IDS);
// Collect samples from the DHT.
let (marked_sample, recapture_sample) =
collect_samples(&dht, pool, MIN_OVERLAP, MAX_RANDOM_NODE_IDS);

// Display the final statistics.
if let Some(estimate) = compute_estimate(&marked_sample, &recapture_sample) {
estimate.display();
} else {
println!("Unable to calculate the DHT size estimate due to insufficient overlap.");
}
});
// Display the final statistics.
if let Some(estimate) = compute_estimate(&marked_sample, &recapture_sample) {
estimate.display();
} else {
println!("Unable to calculate the DHT size estimate due to insufficient overlap.");
}
}

fn collect_samples(
dht: &Arc<Dht>,
dht: &Dht,
pool: ThreadPool,
min_overlap: usize,
max_unique_random_node_ids: usize,
) -> (DashSet<Id>, DashSet<Id>) {
Expand All @@ -135,27 +134,32 @@ fn collect_samples(

loop {
if total_iterations >= max_unique_random_node_ids {
println!("Reached maximum number of random node ID lookups.");
break;
}

// Sample for marked_sample in parallel.
let random_ids: Vec<_> = (0..BATCH_SIZE).map(|_| Id::random()).collect();
random_ids.par_iter().for_each(|random_id| {
if let Ok(nodes) = dht.find_node(*random_id) {
for node in nodes {
marked_sample.insert(*node.id());
let mark_random_ids: Vec<_> = (0..BATCH_SIZE).map(|_| Id::random()).collect();
let recapture_random_ids: Vec<_> = (0..BATCH_SIZE).map(|_| Id::random()).collect();

// Perform sampling in the thread pool
pool.install(|| {
// Sample for marked_sample in parallel.
mark_random_ids.par_iter().for_each(|random_id| {
if let Ok(nodes) = dht.find_node(*random_id) {
for node in nodes {
marked_sample.insert(*node.id());
}
}
}
});
});

// Sample for recapture_sample in parallel.
let random_ids: Vec<_> = (0..BATCH_SIZE).map(|_| Id::random()).collect();
random_ids.par_iter().for_each(|random_id| {
if let Ok(nodes) = dht.find_node(*random_id) {
for node in nodes {
recapture_sample.insert(*node.id());
// Sample for recapture_sample in parallel.
recapture_random_ids.par_iter().for_each(|random_id| {
if let Ok(nodes) = dht.find_node(*random_id) {
for node in nodes {
recapture_sample.insert(*node.id());
}
}
}
});
});

total_iterations += BATCH_SIZE;
Expand All @@ -168,7 +172,7 @@ fn collect_samples(

if let Some(estimate) = compute_estimate(&marked_sample, &recapture_sample) {
size = estimate.estimate;
confidence = estimate.standard_error * 1.96;
confidence = estimate.standard_error * Z_SCORE;
}

info!(
Expand Down Expand Up @@ -228,8 +232,7 @@ fn compute_estimate(
let standard_error = variance.sqrt();

// 95% confidence interval.
let z_score = 1.96;
let margin_of_error = z_score * standard_error;
let margin_of_error = Z_SCORE * standard_error;
let lower_bound = (estimate - margin_of_error).max(0.0);
let upper_bound = estimate + margin_of_error;

Expand Down

0 comments on commit 057de7f

Please # to comment.