diff --git a/evm/src/executor/inspector/cheatcodes/env.rs b/evm/src/executor/inspector/cheatcodes/env.rs index da688b8a6edf0..9c9609f724def 100644 --- a/evm/src/executor/inspector/cheatcodes/env.rs +++ b/evm/src/executor/inspector/cheatcodes/env.rs @@ -51,6 +51,39 @@ pub struct Prank { pub depth: u64, /// Whether the prank stops by itself after the next call pub single_call: bool, + /// Whether the prank has been used yet (false if unused) + pub used: bool, +} + +impl Prank { + pub fn new( + prank_caller: Address, + prank_origin: Address, + new_caller: Address, + new_origin: Option
, + depth: u64, + single_call: bool, + ) -> Prank { + Prank { + prank_caller, + prank_origin, + new_caller, + new_origin, + depth, + single_call, + used: false, + } + } + + /// Apply the prank by setting `used` to true iff it is false + /// Only returns self in the case it is updated (first application) + pub fn first_time_applied(&self) -> Option { + if self.used { + None + } else { + Some(Prank { used: true, ..self.clone() }) + } + } } /// Sets up broadcasting from a script using `origin` as the sender @@ -102,14 +135,18 @@ fn prank( depth: u64, single_call: bool, ) -> Result { + let prank = Prank::new(prank_caller, prank_origin, new_caller, new_origin, depth, single_call); + + if let Some(Prank { used, .. }) = state.prank { + ensure!(used, "You cannot overwrite `prank` until it is applied at least once"); + } + ensure!( state.broadcast.is_none(), - "You cannot `prank` for a broadcasted transaction. \ + "You cannot `prank` for a broadcasted transaction.\ Pass the desired tx.origin into the broadcast cheatcode call" ); - ensure!(state.prank.is_none(), "You have an active prank already."); - let prank = Prank { prank_caller, prank_origin, new_caller, new_origin, depth, single_call }; state.prank = Some(prank); Ok(Bytes::new()) } diff --git a/evm/src/executor/inspector/cheatcodes/mod.rs b/evm/src/executor/inspector/cheatcodes/mod.rs index ce824d935a09a..345e320f1dedc 100644 --- a/evm/src/executor/inspector/cheatcodes/mod.rs +++ b/evm/src/executor/inspector/cheatcodes/mod.rs @@ -607,15 +607,25 @@ where if data.journaled_state.depth() >= prank.depth && call.context.caller == h160_to_b160(prank.prank_caller) { + let mut prank_applied = false; // At the target depth we set `msg.sender` if data.journaled_state.depth() == prank.depth { call.context.caller = h160_to_b160(prank.new_caller); call.transfer.source = h160_to_b160(prank.new_caller); + prank_applied = true; } // At the target depth, or deeper, we set `tx.origin` if let Some(new_origin) = prank.new_origin { data.env.tx.caller = h160_to_b160(new_origin); + prank_applied = true; + } + + // If prank applied for first time, then update + if prank_applied { + if let Some(applied_prank) = prank.first_time_applied() { + self.prank = Some(applied_prank); + } } } } diff --git a/testdata/cheats/Prank.t.sol b/testdata/cheats/Prank.t.sol index fccccff7c9697..0dddb2ffe702d 100644 --- a/testdata/cheats/Prank.t.sol +++ b/testdata/cheats/Prank.t.sol @@ -118,6 +118,159 @@ contract PrankTest is DSTest { ); } + function testPrank1AfterPrank0(address sender, address origin) public { + // Perform the prank + address oldOrigin = tx.origin; + Victim victim = new Victim(); + cheats.prank(sender); + victim.assertCallerAndOrigin( + sender, "msg.sender was not set during prank", oldOrigin, "tx.origin was not set during prank" + ); + + // Ensure we cleaned up correctly + victim.assertCallerAndOrigin( + address(this), "msg.sender was not cleaned up", oldOrigin, "tx.origin invariant failed" + ); + + // Overwrite the prank + cheats.prank(sender, origin); + victim.assertCallerAndOrigin( + sender, "msg.sender was not set during prank", origin, "tx.origin invariant failed" + ); + + // Ensure we cleaned up correctly + victim.assertCallerAndOrigin( + address(this), "msg.sender was not cleaned up", oldOrigin, "tx.origin invariant failed" + ); + } + + function testPrank0AfterPrank1(address sender, address origin) public { + // Perform the prank + address oldOrigin = tx.origin; + Victim victim = new Victim(); + cheats.prank(sender, origin); + victim.assertCallerAndOrigin( + sender, "msg.sender was not set during prank", origin, "tx.origin was not set during prank" + ); + + // Ensure we cleaned up correctly + victim.assertCallerAndOrigin( + address(this), "msg.sender was not cleaned up", oldOrigin, "tx.origin invariant failed" + ); + + // Overwrite the prank + cheats.prank(sender); + victim.assertCallerAndOrigin( + sender, "msg.sender was not set during prank", oldOrigin, "tx.origin invariant failed" + ); + + // Ensure we cleaned up correctly + victim.assertCallerAndOrigin( + address(this), "msg.sender was not cleaned up", oldOrigin, "tx.origin invariant failed" + ); + } + + function testStartPrank0AfterPrank1(address sender, address origin) public { + // Perform the prank + address oldOrigin = tx.origin; + Victim victim = new Victim(); + cheats.startPrank(sender, origin); + victim.assertCallerAndOrigin( + sender, "msg.sender was not set during prank", origin, "tx.origin was not set during prank" + ); + + // Overwrite the prank + cheats.startPrank(sender); + victim.assertCallerAndOrigin( + sender, "msg.sender was not set during prank", oldOrigin, "tx.origin invariant failed" + ); + + cheats.stopPrank(); + // Ensure we cleaned up correctly after stopping the prank + victim.assertCallerAndOrigin( + address(this), "msg.sender was not cleaned up", oldOrigin, "tx.origin invariant failed" + ); + } + + function testStartPrank1AfterStartPrank0(address sender, address origin) public { + // Perform the prank + address oldOrigin = tx.origin; + Victim victim = new Victim(); + cheats.startPrank(sender); + victim.assertCallerAndOrigin( + sender, "msg.sender was not set during prank", oldOrigin, "tx.origin was set during prank incorrectly" + ); + + // Ensure prank is still up as startPrank covers multiple calls + victim.assertCallerAndOrigin( + sender, "msg.sender was cleaned up incorrectly", oldOrigin, "tx.origin invariant failed" + ); + + // Overwrite the prank + cheats.startPrank(sender, origin); + victim.assertCallerAndOrigin(sender, "msg.sender was not set during prank", origin, "tx.origin was not set"); + + // Ensure prank is still up as startPrank covers multiple calls + victim.assertCallerAndOrigin( + sender, "msg.sender was cleaned up incorrectly", origin, "tx.origin invariant failed" + ); + + cheats.stopPrank(); + // Ensure everything is back to normal after stopPrank + victim.assertCallerAndOrigin( + address(this), "msg.sender was not cleaned up", oldOrigin, "tx.origin invariant failed" + ); + } + + function testFailOverwriteUnusedPrank(address sender, address origin) public { + // Set the prank, but not use it + address oldOrigin = tx.origin; + Victim victim = new Victim(); + cheats.startPrank(sender, origin); + // try to overwrite the prank. This should fail. + cheats.startPrank(address(this), origin); + } + + function testFailOverwriteUnusedPrankAfterSuccessfulPrank(address sender, address origin) public { + // Set the prank, but not use it + address oldOrigin = tx.origin; + Victim victim = new Victim(); + cheats.startPrank(sender, origin); + victim.assertCallerAndOrigin( + sender, "msg.sender was not set during prank", origin, "tx.origin was set during prank incorrectly" + ); + cheats.startPrank(address(this), origin); + // try to overwrite the prank. This should fail. + cheats.startPrank(sender, origin); + } + + function testStartPrank0AfterStartPrank1(address sender, address origin) public { + // Perform the prank + address oldOrigin = tx.origin; + Victim victim = new Victim(); + cheats.startPrank(sender, origin); + victim.assertCallerAndOrigin( + sender, "msg.sender was not set during prank", origin, "tx.origin was not set during prank" + ); + + // Ensure prank is still ongoing as we haven't called stopPrank + victim.assertCallerAndOrigin( + sender, "msg.sender was cleaned up incorrectly", origin, "tx.origin was cleaned up incorrectly" + ); + + // Overwrite the prank + cheats.startPrank(sender); + victim.assertCallerAndOrigin( + sender, "msg.sender was not set during prank", oldOrigin, "tx.origin was not reset correctly" + ); + + cheats.stopPrank(); + // Ensure we cleaned up correctly + victim.assertCallerAndOrigin( + address(this), "msg.sender was not cleaned up", oldOrigin, "tx.origin invariant failed" + ); + } + function testPrankConstructorSender(address sender) public { cheats.prank(sender); ConstructorVictim victim = new ConstructorVictim(