diff --git a/evm/src/executor/inspector/cheatcodes/env.rs b/evm/src/executor/inspector/cheatcodes/env.rs
index da688b8a6edf0..9c9609f724def 100644
--- a/evm/src/executor/inspector/cheatcodes/env.rs
+++ b/evm/src/executor/inspector/cheatcodes/env.rs
@@ -51,6 +51,39 @@ pub struct Prank {
pub depth: u64,
/// Whether the prank stops by itself after the next call
pub single_call: bool,
+ /// Whether the prank has been used yet (false if unused)
+ pub used: bool,
+}
+
+impl Prank {
+ pub fn new(
+ prank_caller: Address,
+ prank_origin: Address,
+ new_caller: Address,
+ new_origin: Option
,
+ depth: u64,
+ single_call: bool,
+ ) -> Prank {
+ Prank {
+ prank_caller,
+ prank_origin,
+ new_caller,
+ new_origin,
+ depth,
+ single_call,
+ used: false,
+ }
+ }
+
+ /// Apply the prank by setting `used` to true iff it is false
+ /// Only returns self in the case it is updated (first application)
+ pub fn first_time_applied(&self) -> Option {
+ if self.used {
+ None
+ } else {
+ Some(Prank { used: true, ..self.clone() })
+ }
+ }
}
/// Sets up broadcasting from a script using `origin` as the sender
@@ -102,14 +135,18 @@ fn prank(
depth: u64,
single_call: bool,
) -> Result {
+ let prank = Prank::new(prank_caller, prank_origin, new_caller, new_origin, depth, single_call);
+
+ if let Some(Prank { used, .. }) = state.prank {
+ ensure!(used, "You cannot overwrite `prank` until it is applied at least once");
+ }
+
ensure!(
state.broadcast.is_none(),
- "You cannot `prank` for a broadcasted transaction. \
+ "You cannot `prank` for a broadcasted transaction.\
Pass the desired tx.origin into the broadcast cheatcode call"
);
- ensure!(state.prank.is_none(), "You have an active prank already.");
- let prank = Prank { prank_caller, prank_origin, new_caller, new_origin, depth, single_call };
state.prank = Some(prank);
Ok(Bytes::new())
}
diff --git a/evm/src/executor/inspector/cheatcodes/mod.rs b/evm/src/executor/inspector/cheatcodes/mod.rs
index ce824d935a09a..345e320f1dedc 100644
--- a/evm/src/executor/inspector/cheatcodes/mod.rs
+++ b/evm/src/executor/inspector/cheatcodes/mod.rs
@@ -607,15 +607,25 @@ where
if data.journaled_state.depth() >= prank.depth &&
call.context.caller == h160_to_b160(prank.prank_caller)
{
+ let mut prank_applied = false;
// At the target depth we set `msg.sender`
if data.journaled_state.depth() == prank.depth {
call.context.caller = h160_to_b160(prank.new_caller);
call.transfer.source = h160_to_b160(prank.new_caller);
+ prank_applied = true;
}
// At the target depth, or deeper, we set `tx.origin`
if let Some(new_origin) = prank.new_origin {
data.env.tx.caller = h160_to_b160(new_origin);
+ prank_applied = true;
+ }
+
+ // If prank applied for first time, then update
+ if prank_applied {
+ if let Some(applied_prank) = prank.first_time_applied() {
+ self.prank = Some(applied_prank);
+ }
}
}
}
diff --git a/testdata/cheats/Prank.t.sol b/testdata/cheats/Prank.t.sol
index fccccff7c9697..0dddb2ffe702d 100644
--- a/testdata/cheats/Prank.t.sol
+++ b/testdata/cheats/Prank.t.sol
@@ -118,6 +118,159 @@ contract PrankTest is DSTest {
);
}
+ function testPrank1AfterPrank0(address sender, address origin) public {
+ // Perform the prank
+ address oldOrigin = tx.origin;
+ Victim victim = new Victim();
+ cheats.prank(sender);
+ victim.assertCallerAndOrigin(
+ sender, "msg.sender was not set during prank", oldOrigin, "tx.origin was not set during prank"
+ );
+
+ // Ensure we cleaned up correctly
+ victim.assertCallerAndOrigin(
+ address(this), "msg.sender was not cleaned up", oldOrigin, "tx.origin invariant failed"
+ );
+
+ // Overwrite the prank
+ cheats.prank(sender, origin);
+ victim.assertCallerAndOrigin(
+ sender, "msg.sender was not set during prank", origin, "tx.origin invariant failed"
+ );
+
+ // Ensure we cleaned up correctly
+ victim.assertCallerAndOrigin(
+ address(this), "msg.sender was not cleaned up", oldOrigin, "tx.origin invariant failed"
+ );
+ }
+
+ function testPrank0AfterPrank1(address sender, address origin) public {
+ // Perform the prank
+ address oldOrigin = tx.origin;
+ Victim victim = new Victim();
+ cheats.prank(sender, origin);
+ victim.assertCallerAndOrigin(
+ sender, "msg.sender was not set during prank", origin, "tx.origin was not set during prank"
+ );
+
+ // Ensure we cleaned up correctly
+ victim.assertCallerAndOrigin(
+ address(this), "msg.sender was not cleaned up", oldOrigin, "tx.origin invariant failed"
+ );
+
+ // Overwrite the prank
+ cheats.prank(sender);
+ victim.assertCallerAndOrigin(
+ sender, "msg.sender was not set during prank", oldOrigin, "tx.origin invariant failed"
+ );
+
+ // Ensure we cleaned up correctly
+ victim.assertCallerAndOrigin(
+ address(this), "msg.sender was not cleaned up", oldOrigin, "tx.origin invariant failed"
+ );
+ }
+
+ function testStartPrank0AfterPrank1(address sender, address origin) public {
+ // Perform the prank
+ address oldOrigin = tx.origin;
+ Victim victim = new Victim();
+ cheats.startPrank(sender, origin);
+ victim.assertCallerAndOrigin(
+ sender, "msg.sender was not set during prank", origin, "tx.origin was not set during prank"
+ );
+
+ // Overwrite the prank
+ cheats.startPrank(sender);
+ victim.assertCallerAndOrigin(
+ sender, "msg.sender was not set during prank", oldOrigin, "tx.origin invariant failed"
+ );
+
+ cheats.stopPrank();
+ // Ensure we cleaned up correctly after stopping the prank
+ victim.assertCallerAndOrigin(
+ address(this), "msg.sender was not cleaned up", oldOrigin, "tx.origin invariant failed"
+ );
+ }
+
+ function testStartPrank1AfterStartPrank0(address sender, address origin) public {
+ // Perform the prank
+ address oldOrigin = tx.origin;
+ Victim victim = new Victim();
+ cheats.startPrank(sender);
+ victim.assertCallerAndOrigin(
+ sender, "msg.sender was not set during prank", oldOrigin, "tx.origin was set during prank incorrectly"
+ );
+
+ // Ensure prank is still up as startPrank covers multiple calls
+ victim.assertCallerAndOrigin(
+ sender, "msg.sender was cleaned up incorrectly", oldOrigin, "tx.origin invariant failed"
+ );
+
+ // Overwrite the prank
+ cheats.startPrank(sender, origin);
+ victim.assertCallerAndOrigin(sender, "msg.sender was not set during prank", origin, "tx.origin was not set");
+
+ // Ensure prank is still up as startPrank covers multiple calls
+ victim.assertCallerAndOrigin(
+ sender, "msg.sender was cleaned up incorrectly", origin, "tx.origin invariant failed"
+ );
+
+ cheats.stopPrank();
+ // Ensure everything is back to normal after stopPrank
+ victim.assertCallerAndOrigin(
+ address(this), "msg.sender was not cleaned up", oldOrigin, "tx.origin invariant failed"
+ );
+ }
+
+ function testFailOverwriteUnusedPrank(address sender, address origin) public {
+ // Set the prank, but not use it
+ address oldOrigin = tx.origin;
+ Victim victim = new Victim();
+ cheats.startPrank(sender, origin);
+ // try to overwrite the prank. This should fail.
+ cheats.startPrank(address(this), origin);
+ }
+
+ function testFailOverwriteUnusedPrankAfterSuccessfulPrank(address sender, address origin) public {
+ // Set the prank, but not use it
+ address oldOrigin = tx.origin;
+ Victim victim = new Victim();
+ cheats.startPrank(sender, origin);
+ victim.assertCallerAndOrigin(
+ sender, "msg.sender was not set during prank", origin, "tx.origin was set during prank incorrectly"
+ );
+ cheats.startPrank(address(this), origin);
+ // try to overwrite the prank. This should fail.
+ cheats.startPrank(sender, origin);
+ }
+
+ function testStartPrank0AfterStartPrank1(address sender, address origin) public {
+ // Perform the prank
+ address oldOrigin = tx.origin;
+ Victim victim = new Victim();
+ cheats.startPrank(sender, origin);
+ victim.assertCallerAndOrigin(
+ sender, "msg.sender was not set during prank", origin, "tx.origin was not set during prank"
+ );
+
+ // Ensure prank is still ongoing as we haven't called stopPrank
+ victim.assertCallerAndOrigin(
+ sender, "msg.sender was cleaned up incorrectly", origin, "tx.origin was cleaned up incorrectly"
+ );
+
+ // Overwrite the prank
+ cheats.startPrank(sender);
+ victim.assertCallerAndOrigin(
+ sender, "msg.sender was not set during prank", oldOrigin, "tx.origin was not reset correctly"
+ );
+
+ cheats.stopPrank();
+ // Ensure we cleaned up correctly
+ victim.assertCallerAndOrigin(
+ address(this), "msg.sender was not cleaned up", oldOrigin, "tx.origin invariant failed"
+ );
+ }
+
function testPrankConstructorSender(address sender) public {
cheats.prank(sender);
ConstructorVictim victim = new ConstructorVictim(