Advantage of prioritized experience replay on the Blind Cliffwalk problem, as shown by Schaul et al., 2015
NOTE: I couldn't confirm some of the paper's results; the discrepancies are raised as issues in this README. Any help with these issues will be greatly appreciated.
All possible sequences of actions are generated, and the sequences are run in a random order to obtain all possible transitions. Note that the right transitions are repeated: the transition from state 0 to state 1 appears N times (N = number of states), the transition from state 1 to state 2 appears N-1 times, and so on. These transitions form the replay memory. In total, the list contains N(N+1)/2 + N transitions: N(N+1)/2 right transitions plus N wrong transitions, one for each sequence that terminates early.
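A minimal sketch of how the replay memory above could be built. It assumes transitions are stored as (state, action, reward, next_state, done) tuples, actions are encoded as RIGHT = 1 and WRONG = 0, and only the final right action earns reward 1. The function name `build_replay_memory` and this encoding are hypothetical, not the repository's code, and which physical action counts as "right" in each state is abstracted away.

```python
import random

RIGHT, WRONG = 1, 0  # hypothetical action encoding


def build_replay_memory(n_states, seed=0):
    """Sketch: fill the replay memory with every Blind Cliffwalk episode.

    Transitions are (state, action, reward, next_state, done) tuples;
    next_state is never bootstrapped from when done is True.
    """
    episodes = []
    # N episodes: k right actions, then the wrong action at state k ends the episode.
    for k in range(n_states):
        episode = [(s, RIGHT, 0.0, s + 1, False) for s in range(k)]
        episode.append((k, WRONG, 0.0, k, True))
        episodes.append(episode)
    # One episode of N right actions; the last one reaches the goal with reward 1.
    episode = [(s, RIGHT, 0.0, s + 1, False) for s in range(n_states - 1)]
    episode.append((n_states - 1, RIGHT, 1.0, n_states, True))
    episodes.append(episode)
    # Run the episodes in a random order and concatenate their transitions.
    random.Random(seed).shuffle(episodes)
    return [t for episode in episodes for t in episode]


memory = build_replay_memory(5)
print(len(memory))  # 20, i.e. 5 * 6 / 2 + 5
```

Counting the tuples that start with `(0, RIGHT)` gives N, and those that start with `(1, RIGHT)` give N-1, matching the repetition pattern described above.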
Issue - The paper states that there should be 2^N action sequences. However, I could only find one sequence consisting entirely of right actions and N other sequences that terminate with the wrong action, which gives N(N+1)/2 + N transitions in the replay memory.
The ground-truth Q-table is generated by randomly sampling transitions from the replay memory and updating the Q-table with them a large number of times.
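One way to compute the ground-truth Q-table, sketched under assumptions: tabular Q-learning updates on uniformly sampled transitions, a constant learning rate of 0.1, and a discount γ = 1 − 1/N. The helper names `q_update` and `ground_truth_q`, the tuple layout (state, action, reward, next_state, done) and the action columns (0 = wrong, 1 = right) are hypothetical.

```python
import random

import numpy as np


def q_update(q, transition, gamma, lr):
    """One tabular Q-learning update (in place); returns the TD error."""
    s, a, r, s_next, done = transition
    target = r if done else r + gamma * np.max(q[s_next])
    td_error = target - q[s, a]
    q[s, a] += lr * td_error
    return td_error


def ground_truth_q(memory, n_states, gamma, lr=0.1, n_updates=50_000, seed=0):
    """Approximate the true Q-table by replaying uniformly sampled transitions many times."""
    rng = random.Random(seed)
    q = np.zeros((n_states, 2))  # column 0 = wrong action, column 1 = right action
    for _ in range(n_updates):
        q_update(q, rng.choice(memory), gamma, lr)
    return q


# Tiny Blind Cliffwalk replay memory with N = 3.
n = 3
memory = []
for k in range(n):  # k right actions, then the wrong action at state k
    memory += [(s, 1, 0.0, s + 1, False) for s in range(k)] + [(k, 0, 0.0, k, True)]
memory += [(s, 1, 0.0, s + 1, False) for s in range(n - 1)] + [(n - 1, 1, 1.0, n, True)]

q_true = ground_truth_q(memory, n, gamma=1 - 1 / n)
```

Because the rewards are deterministic, the values should converge to Q[s, right] = γ^(N−1−s) and Q[s, wrong] = 0, which makes a quick sanity check for larger N as well.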
Random agent: transitions are sampled uniformly at random from the replay memory.
Oracle: a list of unique transitions is generated for the oracle, which chooses the transition that leads to the lowest global loss. At each update step, the following is done for each transition:
- A copy of the current Q-table is created.
- The transition is run and the copied Q-table is updated.
- The MSE between the updated Q-table and the true Q-table is calculated and stored under the transition's name.
The transition that leads to the minimum MSE is chosen.
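The oracle's update step can be sketched as follows. This is a hypothetical implementation (`oracle_step` and `q_update` are illustrative names), assuming transitions are (state, action, reward, next_state, done) tuples and a tabular Q-learning update; the MSE is taken over the whole Q-table, as described above.

```python
import numpy as np


def q_update(q, transition, gamma, lr):
    """One tabular Q-learning update (in place); returns the TD error."""
    s, a, r, s_next, done = transition
    target = r if done else r + gamma * np.max(q[s_next])
    td_error = target - q[s, a]
    q[s, a] += lr * td_error
    return td_error


def oracle_step(q, unique_transitions, q_true, gamma, lr):
    """Try every transition on a copy of q and apply the one giving the lowest global MSE."""
    best_t, best_mse = None, float("inf")
    for t in unique_transitions:
        q_copy = q.copy()
        q_update(q_copy, t, gamma, lr)
        mse = np.mean((q_copy - q_true) ** 2)
        if mse < best_mse:
            best_t, best_mse = t, mse
    q_update(q, best_t, gamma, lr)  # the real update uses the chosen transition
    return best_t, best_mse


# Two-state example: with gamma = 0.5 the true values are Q[s, right] = 0.5 ** (1 - s).
transitions = [(0, 1, 0.0, 1, False), (1, 1, 1.0, 2, True), (0, 0, 0.0, 0, True), (1, 0, 0.0, 1, True)]
q_true = np.array([[0.0, 0.5], [0.0, 1.0]])
q = np.zeros((2, 2))
best_t, best_mse = oracle_step(q, transitions, q_true, gamma=0.5, lr=0.5)
print(best_t, best_mse)  # (1, 1, 1.0, 2, True) 0.125
```

The unique transitions could be obtained from the replay memory with `list(dict.fromkeys(memory))`, which removes duplicate tuples while keeping first-occurrence order. Each step costs one Q-table copy per unique transition, which is why the oracle becomes slow for large N.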
TD agent: the agent greedily chooses the transition that produced the largest TD error the last time it was run. A binary heap (Python's built-in heapq module) stores all TD errors, so they can be updated and the maximum can be picked efficiently. Initially, all transitions are given large TD errors so that each one is selected once at the beginning. After each transition is run, its TD error is updated in the heap.
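A sketch of the greedy TD agent using heapq. heapq only provides a min-heap and has no decrease-key operation, so this version stores negated priorities and, since only the replayed transition's TD error changes, simply pops the maximum, replays it, and pushes it back with its new error. The class name `GreedyTDReplay`, the initial priority of 1e6 and the use of the absolute TD error are assumptions.

```python
import heapq

import numpy as np


def q_update(q, transition, gamma, lr):
    """One tabular Q-learning update (in place); returns the TD error."""
    s, a, r, s_next, done = transition
    target = r if done else r + gamma * np.max(q[s_next])
    td_error = target - q[s, a]
    q[s, a] += lr * td_error
    return td_error


class GreedyTDReplay:
    """Always replays the transition whose last recorded |TD error| is largest."""

    def __init__(self, transitions, initial_priority=1e6):
        self.transitions = transitions
        # heapq is a min-heap, so priorities are negated. The index breaks ties, so each
        # transition runs once at the start (while real errors stay below initial_priority).
        self.heap = [(-initial_priority, i) for i in range(len(transitions))]
        heapq.heapify(self.heap)

    def step(self, q, gamma, lr):
        _, i = heapq.heappop(self.heap)  # transition with the largest stored |TD error|
        td_error = q_update(q, self.transitions[i], gamma, lr)
        heapq.heappush(self.heap, (-abs(td_error), i))  # store the fresh error
        return i, td_error


transitions = [(0, 1, 0.0, 1, False), (1, 1, 1.0, 2, True)]
agent = GreedyTDReplay(transitions)
q = np.zeros((2, 2))
order = [agent.step(q, gamma=0.5, lr=0.5)[0] for _ in range(3)]
print(order)  # [0, 1, 1]
```

In this small example, index 0 records a TD error of 0 on its first run and is not replayed again as long as index 1 keeps a larger stored error, which shows how greedy prioritization can starve transitions whose stored error is stale.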
STPD agent: the agent chooses stochastically; each transition is selected with probability proportional to its priority p = |TD_error| + epsilon. The transition is sampled using a sum tree structure that I implemented as a class script. Check the sum tree repository for more information.
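A minimal sum tree and stochastic sampling step, sketched under assumptions: this is not the class from the sum tree repository mentioned above, and `SumTree`, `stpd_step`, the initial priority of 1.0 and ε = 1e-3 are hypothetical choices. Leaves hold the priorities p = |TD error| + ε and each internal node holds the sum of its two children, so drawing a uniform value in [0, total) and descending the tree selects transition i with probability p_i / Σp, with O(log n) updates.

```python
import random

import numpy as np


class SumTree:
    """Array-backed binary tree; every internal node stores the sum of its children."""

    def __init__(self, capacity):
        self.capacity = capacity
        # Node 1 is the root, node k has children 2k and 2k + 1,
        # and leaves live at [capacity, 2 * capacity). Works for any capacity.
        self.tree = np.zeros(2 * capacity)

    def update(self, index, priority):
        node = index + self.capacity
        self.tree[node] = priority
        node //= 2
        while node >= 1:  # refresh the sums on the path to the root
            self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]
            node //= 2

    def total(self):
        return self.tree[1]

    def find(self, value):
        """Descend from the root to the leaf whose priority interval contains value."""
        node = 1
        while node < self.capacity:
            left = 2 * node
            if value < self.tree[left]:
                node = left
            else:
                value -= self.tree[left]
                node = left + 1
        return node - self.capacity


def q_update(q, transition, gamma, lr):
    """One tabular Q-learning update (in place); returns the TD error."""
    s, a, r, s_next, done = transition
    target = r if done else r + gamma * np.max(q[s_next])
    td_error = target - q[s, a]
    q[s, a] += lr * td_error
    return td_error


def stpd_step(q, transitions, tree, rng, gamma, lr, eps=1e-3):
    """Sample i with probability p_i / sum(p), replay it, then refresh its priority."""
    i = tree.find(rng.random() * tree.total())
    td_error = q_update(q, transitions[i], gamma, lr)
    tree.update(i, abs(td_error) + eps)
    return i


transitions = [(0, 1, 0.0, 1, False), (1, 1, 1.0, 2, True), (0, 0, 0.0, 0, True), (1, 0, 0.0, 1, True)]
tree = SumTree(len(transitions))
for i in range(len(transitions)):
    tree.update(i, 1.0)  # assumed initial priority
q = np.zeros((2, 2))
rng = random.Random(0)
for _ in range(1000):
    stpd_step(q, transitions, tree, rng, gamma=0.5, lr=0.5)
# q should now be close to [[0, 0.5], [0, 1]]
```

The ε term keeps every priority strictly positive, so transitions whose TD error has dropped to zero can still be sampled occasionally, unlike in the greedy scheme.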
The agents' performance is compared based on how well they reduce the MSE (global loss). The MSE was averaged over 10 episodes for each agent. For N > 500, only 5 episodes were used, and the oracle was dropped because its execution was too slow.
- The Oracle performs better than the other agents for smaller N, but the performance difference diminishes at larger N. In fact, the STPD agent performs better than the Oracle at large N.
Issue - Even for larger numbers of states, the Oracle's performance is only comparable to that of the TD and STPD agents. The paper, however, uses the Oracle's performance as the baseline, since there it outperforms the other agents by a huge margin at higher N.
- At smaller N, the Random agent performs as well as the other agents, but as N increases its performance worsens relative to the other agents.
- The STPD agent gains some advantage over the TD agent as N increases.