diff --git a/crates/chain-state/src/memory_overlay.rs b/crates/chain-state/src/memory_overlay.rs
index 4ba10d34a4a8..bb281a3f25ab 100644
--- a/crates/chain-state/src/memory_overlay.rs
+++ b/crates/chain-state/src/memory_overlay.rs
@@ -1,7 +1,7 @@
 use super::ExecutedBlock;
 use reth_errors::ProviderResult;
 use reth_primitives::{
-    Account, Address, BlockNumber, Bytecode, Bytes, StorageKey, StorageValue, B256,
+    keccak256, Account, Address, BlockNumber, Bytecode, Bytes, StorageKey, StorageValue, B256,
 };
 use reth_storage_api::{
     AccountReader, BlockHashReader, StateProofProvider, StateProvider, StateProviderBox,
@@ -43,18 +43,16 @@ impl MemoryOverlayStateProvider {
     }
 
     /// Return lazy-loaded trie state aggregated from in-memory blocks.
-    fn trie_state(&self) -> MemoryOverlayTrieState {
-        self.trie_state
-            .get_or_init(|| {
-                let mut hashed_state = HashedPostState::default();
-                let mut trie_nodes = TrieUpdates::default();
-                for block in self.in_memory.iter().rev() {
-                    hashed_state.extend_ref(block.hashed_state.as_ref());
-                    trie_nodes.extend_ref(block.trie.as_ref());
-                }
-                MemoryOverlayTrieState { trie_nodes, hashed_state }
-            })
-            .clone()
+    fn trie_state(&self) -> &MemoryOverlayTrieState {
+        self.trie_state.get_or_init(|| {
+            let mut hashed_state = HashedPostState::default();
+            let mut trie_nodes = TrieUpdates::default();
+            for block in self.in_memory.iter().rev() {
+                hashed_state.extend_ref(block.hashed_state.as_ref());
+                trie_nodes.extend_ref(block.trie.as_ref());
+            }
+            MemoryOverlayTrieState { trie_nodes, hashed_state }
+        })
     }
 }
 
@@ -115,7 +113,7 @@ impl StateRootProvider for MemoryOverlayStateProvider {
         state: HashedPostState,
         prefix_sets: TriePrefixSetsMut,
     ) -> ProviderResult<B256> {
-        let MemoryOverlayTrieState { mut trie_nodes, mut hashed_state } = self.trie_state();
+        let MemoryOverlayTrieState { mut trie_nodes, mut hashed_state } = self.trie_state().clone();
         trie_nodes.extend(nodes);
         hashed_state.extend(state);
         self.historical.hashed_state_root_from_nodes(trie_nodes, hashed_state, prefix_sets)
@@ -139,7 +137,7 @@ impl StateRootProvider for MemoryOverlayStateProvider {
         state: HashedPostState,
         prefix_sets: TriePrefixSetsMut,
     ) -> ProviderResult<(B256, TrieUpdates)> {
-        let MemoryOverlayTrieState { mut trie_nodes, mut hashed_state } = self.trie_state();
+        let MemoryOverlayTrieState { mut trie_nodes, mut hashed_state } = self.trie_state().clone();
         trie_nodes.extend(nodes);
         hashed_state.extend(state);
         self.historical.hashed_state_root_from_nodes_with_updates(
@@ -155,18 +153,27 @@ impl StateRootProvider for MemoryOverlayStateProvider {
         address: Address,
         storage: HashedStorage,
     ) -> ProviderResult<B256> {
-        self.historical.hashed_storage_root(address, storage)
+        let mut hashed_storage = self
+            .trie_state()
+            .hashed_state
+            .storages
+            .get(&keccak256(address))
+            .cloned()
+            .unwrap_or_default();
+        hashed_storage.extend(&storage);
+        self.historical.hashed_storage_root(address, hashed_storage)
     }
 }
 
 impl StateProofProvider for MemoryOverlayStateProvider {
+    // TODO: Currently this does not reuse available in-memory trie nodes.
     fn hashed_proof(
         &self,
         state: HashedPostState,
         address: Address,
         slots: &[B256],
     ) -> ProviderResult<AccountProof> {
-        let MemoryOverlayTrieState { mut hashed_state, .. } = self.trie_state();
+        let mut hashed_state = self.trie_state().hashed_state.clone();
         hashed_state.extend(state);
         self.historical.hashed_proof(hashed_state, address, slots)
     }
@@ -177,7 +184,7 @@ impl StateProofProvider for MemoryOverlayStateProvider {
         overlay: HashedPostState,
         target: HashedPostState,
     ) -> ProviderResult<HashMap<B256, Bytes>> {
-        let MemoryOverlayTrieState { mut hashed_state, .. } = self.trie_state();
+        let mut hashed_state = self.trie_state().hashed_state.clone();
         hashed_state.extend(overlay);
         self.historical.witness(hashed_state, target)
     }