diff --git a/std/src/sync/rwlock/tests.rs b/std/src/sync/rwlock/tests.rs index 37a2e41641ac1..a4af49dc82cce 100644 --- a/std/src/sync/rwlock/tests.rs +++ b/std/src/sync/rwlock/tests.rs @@ -3,8 +3,8 @@ use rand::Rng; use crate::sync::atomic::{AtomicUsize, Ordering}; use crate::sync::mpsc::channel; use crate::sync::{ - Arc, MappedRwLockReadGuard, MappedRwLockWriteGuard, RwLock, RwLockReadGuard, RwLockWriteGuard, - TryLockError, + Arc, Barrier, MappedRwLockReadGuard, MappedRwLockWriteGuard, RwLock, RwLockReadGuard, + RwLockWriteGuard, TryLockError, }; use crate::thread; @@ -501,3 +501,123 @@ fn panic_while_mapping_write_unlocked_poison() { drop(lock); } + +#[test] +fn test_downgrade_basic() { + let r = RwLock::new(()); + + let write_guard = r.write().unwrap(); + let _read_guard = RwLockWriteGuard::downgrade(write_guard); +} + +#[test] +fn test_downgrade_readers() { + // This test creates 1 writing thread and `R` reader threads doing `N` iterations. + const R: usize = 10; + const N: usize = if cfg!(target_pointer_width = "64") { 100 } else { 20 }; + + // The writer thread will constantly update the value inside the `RwLock`, and this test will + // only pass if every reader observes all values between 0 and `N`. + let rwlock = Arc::new(RwLock::new(0)); + let barrier = Arc::new(Barrier::new(R + 1)); + + // Create the writing thread. + let r_writer = rwlock.clone(); + let b_writer = barrier.clone(); + thread::spawn(move || { + for i in 0..N { + let mut write_guard = r_writer.write().unwrap(); + *write_guard = i; + + let read_guard = RwLockWriteGuard::downgrade(write_guard); + assert_eq!(*read_guard, i); + + // Wait for all readers to observe the new value. + b_writer.wait(); + } + }); + + for _ in 0..R { + let rwlock = rwlock.clone(); + let barrier = barrier.clone(); + thread::spawn(move || { + // Every reader thread needs to observe every value up to `N`. + for i in 0..N { + let read_guard = rwlock.read().unwrap(); + assert_eq!(*read_guard, i); + drop(read_guard); + + // Wait for everyone to read and for the writer to change the value again. + barrier.wait(); + + // Spin until the writer has changed the value. + loop { + let read_guard = rwlock.read().unwrap(); + assert!(*read_guard >= i); + + if *read_guard > i { + break; + } + } + } + }); + } +} + +#[test] +fn test_downgrade_atomic() { + const NEW_VALUE: i32 = -1; + + // This test checks that `downgrade` is atomic, meaning as soon as a write lock has been + // downgraded, the lock must be in read mode and no other threads can take the write lock to + // modify the protected value. + + // `W` is the number of evil writer threads. + const W: usize = if cfg!(target_pointer_width = "64") { 100 } else { 20 }; + let rwlock = Arc::new(RwLock::new(0)); + + // Spawns many evil writer threads that will try and write to the locked value before the + // initial writer (who has the exclusive lock) can read after it downgrades. + // If the `RwLock` behaves correctly, then the initial writer should read the value it wrote + // itself as no other thread should be able to mutate the protected value. + + // Put the lock in write mode, causing all future threads trying to access this go to sleep. + let mut main_write_guard = rwlock.write().unwrap(); + + // Spawn all of the evil writer threads. They will each increment the protected value by 1. + let handles: Vec<_> = (0..W) + .map(|_| { + let rwlock = rwlock.clone(); + thread::spawn(move || { + // Will go to sleep since the main thread initially has the write lock. + let mut evil_guard = rwlock.write().unwrap(); + *evil_guard += 1; + }) + }) + .collect(); + + // Wait for a good amount of time so that evil threads go to sleep. + // Note: this is not strictly necessary... + let eternity = crate::time::Duration::from_millis(42); + thread::sleep(eternity); + + // Once everyone is asleep, set the value to `NEW_VALUE`. + *main_write_guard = NEW_VALUE; + + // Atomically downgrade the write guard into a read guard. + let main_read_guard = RwLockWriteGuard::downgrade(main_write_guard); + + // If the above is not atomic, then it would be possible for an evil thread to get in front of + // this read and change the value to be non-negative. + assert_eq!(*main_read_guard, NEW_VALUE, "`downgrade` was not atomic"); + + // Drop the main read guard and allow the evil writer threads to start incrementing. + drop(main_read_guard); + + for handle in handles { + handle.join().unwrap(); + } + + let final_check = rwlock.read().unwrap(); + assert_eq!(*final_check, W as i32 + NEW_VALUE); +}