Skip to content

Custom executor example #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ testcontainers = "0.23"
testcontainers-modules = { version = "0.11", features = ["localstack"] }
tokio = { version = "1.43.0", features = ["rt-multi-thread", "macros"] }
tokio-stream = "0.1.17"
rayon = "1.10.0"
async-task = "4.7.1"

[[bin]]
name = "single_runtime"
Expand All @@ -21,3 +23,7 @@ path = "src/single_runtime.rs"
[[bin]]
name = "two_runtimes"
path = "src/two_runtimes.rs"

[[bin]]
name = "custom_runtime"
path = "src/custom_runtime.rs"
97 changes: 97 additions & 0 deletions src/custom_runtime.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
use std::{iter, pin::Pin, sync::Arc, time::Duration};

use crate::executor::AsyncExecutor;
use futures::{
stream::{self, BoxStream},
Stream, StreamExt, TryStreamExt,
};
use object_store::{aws::AmazonS3Builder, ObjectStore, PutPayload, Result};

mod executor;
mod localstack;

/// Seconds that each `cpu_work` call blocks its thread to simulate CPU-bound work.
static CPU_TIME: u64 = 2;
/// Number of times `io_stream` fetches the test object per execution stream.
static N_FILES: usize = 2;
/// Key under which the test object is stored in, and read back from, the object store.
static OBJECT_KEY: &str = "test";
/// Amount subtracted from the available parallelism when deciding how many
/// execution tasks `main` spawns (the cores left unoccupied).
static N_IO_THREADS: usize = 2;

/// Starts a localstack container, uploads a 10 MiB test object, then drives
/// several execution streams on the custom [`AsyncExecutor`] and waits for all
/// of them to finish.
///
/// # Panics
///
/// Panics if the container cannot be queried, the object store cannot be
/// built, the upload fails, or any execution stream yields an error.
#[tokio::main]
async fn main() {
    // Fall back to a single thread when the platform cannot report its
    // parallelism instead of aborting the example.
    let num_threads = std::thread::available_parallelism().map_or(1, |n| n.get());

    // Start localstack container
    let localstack = localstack::localstack_container().await;
    let localstack_host = localstack.get_host().await.unwrap();
    let localstack_port = localstack.get_host_port_ipv4(4566).await.unwrap();

    let object_store: Arc<dyn ObjectStore> = Arc::new(
        AmazonS3Builder::new()
            .with_config("aws_access_key_id".parse().unwrap(), "user")
            .with_config("aws_secret_access_key".parse().unwrap(), "password")
            .with_config(
                "endpoint".parse().unwrap(),
                format!("http://{}:{}", localstack_host, localstack_port),
            )
            .with_config("region".parse().unwrap(), "us-east-1")
            .with_config("allow_http".parse().unwrap(), "true")
            .with_config("bucket".parse().unwrap(), "warehouse")
            .build()
            .unwrap(),
    );

    // Insert object
    object_store
        .put(
            &OBJECT_KEY.into(),
            PutPayload::from_static(&[0; 10 * 1024 * 1024]),
        )
        .await
        .unwrap();

    let executor = AsyncExecutor::new();

    // Leave `N_IO_THREADS` cores unoccupied. `saturating_sub` avoids the usize
    // underflow on hosts with that many cores or fewer, and `max(1)`
    // guarantees that at least one execution stream is still run there.
    let n_tasks = num_threads.saturating_sub(N_IO_THREADS).max(1);

    let mut handles = Vec::with_capacity(n_tasks);
    for _ in 0..n_tasks {
        let handle = executor.spawn({
            let object_store = object_store.clone();
            async move {
                execution_stream(object_store)
                    .try_collect::<Vec<Vec<_>>>()
                    .await
                    .unwrap();
            }
        });
        // The handle must be kept alive: dropping it cancels the task.
        handles.push(handle);
    }

    futures::future::join_all(handles).await;
}

/// Builds the full pipeline for one task: the chunks produced by
/// [`io_stream`] are passed through [`cpu_work`] twice, one stage after the
/// other. Errors from the I/O stage are forwarded untouched.
fn execution_stream(
    object_store: Arc<dyn ObjectStore>,
) -> Pin<Box<dyn Stream<Item = Result<Vec<u8>, object_store::Error>> + Send>> {
    let fetched = io_stream(object_store);
    let first_pass = fetched.map_ok(cpu_work);
    let second_pass = first_pass.map_ok(cpu_work);
    Box::pin(second_pass)
}

fn io_stream(
object_store: Arc<dyn ObjectStore>,
) -> BoxStream<'static, Result<Vec<u8>, object_store::Error>> {
Box::pin(
stream::iter(iter::repeat_n(object_store, N_FILES))
.then(|object_store| async move {
object_store
.get(&OBJECT_KEY.into())
.await
.unwrap()
.into_stream()
.map_ok(|x| Vec::from(x))
})
.flatten(),
)
}

/// Simulates a CPU-bound processing step by blocking the calling thread for
/// `CPU_TIME` seconds, then hands the input back unchanged.
fn cpu_work(bytes: Vec<u8>) -> Vec<u8> {
    let simulated_cost = Duration::from_secs(CPU_TIME);
    std::thread::sleep(simulated_cost);
    bytes
}
77 changes: 77 additions & 0 deletions src/executor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
use async_task::Runnable;
use futures::FutureExt;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

/// Executor for futures that may block the thread polling them.
///
/// Spawned futures are polled on a rayon thread pool, with the tokio runtime
/// entered for the duration of each poll.
pub struct AsyncExecutor {
    /// Handle to the tokio runtime that is entered while a future is polled.
    io: tokio::runtime::Handle,
    /// Rayon pool on which spawned futures are polled.
    cpu: Arc<rayon::ThreadPool>,
}

impl AsyncExecutor {
    /// Number of rayon threads requested by [`AsyncExecutor::new`].
    const DEFAULT_CPU_THREADS: usize = 8;

    /// Creates an executor with the default CPU pool size.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`AsyncExecutor::with_cpu_threads`].
    pub fn new() -> Self {
        Self::with_cpu_threads(Self::DEFAULT_CPU_THREADS)
    }

    /// Creates an executor whose rayon pool is configured with `num_threads`
    /// threads and whose I/O is driven by the current tokio runtime.
    ///
    /// # Panics
    ///
    /// Panics if called outside the context of a tokio runtime, or if the
    /// rayon thread pool cannot be built.
    pub fn with_cpu_threads(num_threads: usize) -> Self {
        let io = tokio::runtime::Handle::current();
        // NOTE(review): `use_current_thread` registers the calling thread as a
        // member of the pool; confirm against the rayon docs that this is
        // intended when the caller is a tokio worker thread.
        let cpu = rayon::ThreadPoolBuilder::new()
            .num_threads(num_threads)
            .use_current_thread()
            .build()
            .unwrap();

        let cpu = Arc::new(cpu);
        Self { io, cpu }
    }

    /// Spawns `fut` onto the CPU pool and returns a handle to its output.
    ///
    /// Every poll of `fut` happens inside a job on the rayon pool, with the
    /// tokio runtime entered so that I/O resources and timers keep working.
    /// The output is delivered through a oneshot channel owned by the
    /// returned [`SpawnHandle`]; dropping that handle drops the underlying
    /// task handle as well.
    pub fn spawn<F>(&self, fut: F) -> SpawnHandle<F::Output>
    where
        F: Future + Send + 'static,
        F::Output: Send + 'static,
    {
        let (sender, receiver) = futures::channel::oneshot::channel();
        let handle = self.io.clone();

        // This box is technically unnecessary, but avoids some pin shenanigans
        let mut boxed = Box::pin(fut);

        // Enter tokio runtime whilst polling future - allowing IO and timers to work
        let io_fut = futures::future::poll_fn(move |cx| {
            let _guard = handle.enter();
            boxed.poll_unpin(cx)
        });
        // Route result back to oneshot; a send error only means the receiver
        // is already gone, so it is deliberately ignored.
        let remote_fut = io_fut.map(|out| {
            let _ = sender.send(out);
        });

        // Task execution is scheduled on rayon: each time the task needs to
        // be polled, the schedule callback submits a job that runs it once.
        let cpu = self.cpu.clone();
        let (runnable, task) = async_task::spawn(remote_fut, move |runnable: Runnable<()>| {
            cpu.spawn(move || {
                let _ = runnable.run();
            });
        });
        // Kick off the first poll.
        runnable.schedule();
        SpawnHandle {
            _task: task,
            receiver,
        }
    }
}

/// Handle returned by [`AsyncExecutor`]
///
/// Awaiting the handle yields `Some(output)` once the spawned future has
/// completed, or `None` if the result channel was closed without a value
/// being sent.
///
/// Cancels task on drop
#[must_use = "dropping a `SpawnHandle` cancels the spawned task"]
pub struct SpawnHandle<T> {
    /// Receives the spawned future's output.
    receiver: futures::channel::oneshot::Receiver<T>,
    /// Held only for its drop behaviour; never read.
    _task: async_task::Task<()>,
}

impl<T> Future for SpawnHandle<T> {
    /// `Some(value)` when the task delivered its result, `None` when the
    /// sending side was dropped without sending one.
    type Output = Option<T>;

    /// Polls the result channel and converts a receive error into `None`.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        match self.receiver.poll_unpin(cx) {
            Poll::Ready(outcome) => Poll::Ready(outcome.ok()),
            Poll::Pending => Poll::Pending,
        }
    }
}