Skip to content

Commit

Permalink
fix: make begin,commit,rollback cancel-safe in sqlite (#2054) (#2057)
Browse files Browse the repository at this point in the history
  • Loading branch information
madadam authored Sep 13, 2022
1 parent 09717e1 commit f38c739
Showing 1 changed file with 107 additions and 10 deletions.
117 changes: 107 additions & 10 deletions sqlx-core/src/sqlite/connection/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@ enum Command {
tx: flume::Sender<Result<Either<SqliteQueryResult, SqliteRow>, Error>>,
},
Begin {
tx: oneshot::Sender<Result<(), Error>>,
tx: rendezvous_oneshot::Sender<Result<(), Error>>,
},
Commit {
tx: oneshot::Sender<Result<(), Error>>,
tx: rendezvous_oneshot::Sender<Result<(), Error>>,
},
Rollback {
tx: Option<oneshot::Sender<Result<(), Error>>>,
tx: Option<rendezvous_oneshot::Sender<Result<(), Error>>>,
},
CreateCollation {
create_collation:
Expand Down Expand Up @@ -116,6 +116,11 @@ impl ConnectionWorker {
return;
}

// If COMMIT or ROLLBACK is processed but not acknowledged, there would be another
// ROLLBACK sent when the `Transaction` drops. We need to ignore it otherwise we
// would rollback an already completed transaction.
let mut ignore_next_start_rollback = false;

for cmd in command_rx {
match cmd {
Command::Prepare { query, tx } => {
Expand Down Expand Up @@ -162,8 +167,27 @@ impl ConnectionWorker {
.map(|_| {
conn.transaction_depth += 1;
});

tx.send(res).ok();
let res_ok = res.is_ok();

if tx.blocking_send(res).is_err() && res_ok {
// The BEGIN was processed but not acknowledged. This means no
// `Transaction` was created and so there is no way to commit /
// rollback this transaction. We need to roll it back
// immediately otherwise it would remain started forever.
if let Err(e) = conn
.handle
.exec(rollback_ansi_transaction_sql(depth + 1))
.map(|_| {
conn.transaction_depth -= 1;
})
{
// The rollback failed. To prevent leaving the connection
// in an inconsistent state we shutdown this worker which
// causes any subsequent operation on the connection to fail.
log::error!("failed to rollback cancelled transaction: {}", e);
break;
}
}
}
Command::Commit { tx } => {
let depth = conn.transaction_depth;
Expand All @@ -177,10 +201,21 @@ impl ConnectionWorker {
} else {
Ok(())
};
let res_ok = res.is_ok();

tx.send(res).ok();
if tx.blocking_send(res).is_err() && res_ok {
// The COMMIT was processed but not acknowledged. This means that
// the `Transaction` doesn't know it was committed and will try to
// rollback on drop. We need to ignore that rollback.
ignore_next_start_rollback = true;
}
}
Command::Rollback { tx } => {
if ignore_next_start_rollback && tx.is_none() {
ignore_next_start_rollback = false;
continue;
}

let depth = conn.transaction_depth;

let res = if depth > 0 {
Expand All @@ -193,8 +228,16 @@ impl ConnectionWorker {
Ok(())
};

let res_ok = res.is_ok();

if let Some(tx) = tx {
tx.send(res).ok();
if tx.blocking_send(res).is_err() && res_ok {
// The ROLLBACK was processed but not acknowledged. This means
// that the `Transaction` doesn't know it was rolled back and
// will try to rollback again on drop. We need to ignore that
// rollback.
ignore_next_start_rollback = true;
}
}
}
Command::CreateCollation { create_collation } => {
Expand Down Expand Up @@ -268,15 +311,17 @@ impl ConnectionWorker {
}

pub(crate) async fn begin(&mut self) -> Result<(), Error> {
self.oneshot_cmd(|tx| Command::Begin { tx }).await?
self.oneshot_cmd_with_ack(|tx| Command::Begin { tx })
.await?
}

pub(crate) async fn commit(&mut self) -> Result<(), Error> {
self.oneshot_cmd(|tx| Command::Commit { tx }).await?
self.oneshot_cmd_with_ack(|tx| Command::Commit { tx })
.await?
}

pub(crate) async fn rollback(&mut self) -> Result<(), Error> {
self.oneshot_cmd(|tx| Command::Rollback { tx: Some(tx) })
self.oneshot_cmd_with_ack(|tx| Command::Rollback { tx: Some(tx) })
.await?
}

Expand Down Expand Up @@ -304,6 +349,20 @@ impl ConnectionWorker {
rx.await.map_err(|_| Error::WorkerCrashed)
}

async fn oneshot_cmd_with_ack<F, T>(&mut self, command: F) -> Result<T, Error>
where
F: FnOnce(rendezvous_oneshot::Sender<T>) -> Command,
{
let (tx, rx) = rendezvous_oneshot::channel();

self.command_tx
.send_async(command(tx))
.await
.map_err(|_| Error::WorkerCrashed)?;

rx.recv().await.map_err(|_| Error::WorkerCrashed)
}

pub fn create_collation(
&mut self,
name: &str,
Expand Down Expand Up @@ -387,3 +446,41 @@ fn prepare(conn: &mut ConnectionState, query: &str) -> Result<SqliteStatement<'s
fn update_cached_statements_size(conn: &ConnectionState, size: &AtomicUsize) {
size.store(conn.statements.len(), Ordering::Release);
}

// A oneshot channel where send completes only after the receiver receives the value.
mod rendezvous_oneshot {
use super::oneshot::{self, Canceled};

pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
let (inner_tx, inner_rx) = oneshot::channel();
(Sender { inner: inner_tx }, Receiver { inner: inner_rx })
}

pub struct Sender<T> {
inner: oneshot::Sender<(T, oneshot::Sender<()>)>,
}

impl<T> Sender<T> {
pub async fn send(self, value: T) -> Result<(), Canceled> {
let (ack_tx, ack_rx) = oneshot::channel();
self.inner.send((value, ack_tx)).map_err(|_| Canceled)?;
ack_rx.await
}

pub fn blocking_send(self, value: T) -> Result<(), Canceled> {
futures_executor::block_on(self.send(value))
}
}

pub struct Receiver<T> {
inner: oneshot::Receiver<(T, oneshot::Sender<()>)>,
}

impl<T> Receiver<T> {
pub async fn recv(self) -> Result<T, Canceled> {
let (value, ack_tx) = self.inner.await?;
ack_tx.send(()).map_err(|_| Canceled)?;
Ok(value)
}
}
}

0 comments on commit f38c739

Please # to comment.