This is the code repository associated with Zitovsky et al. 2023 (see the References section at the end of this document). For the remainder of this document, whenever we refer to "our" paper, we are referring to Zitovsky et al. 2023, and whenever we refer to "our" results, tables and figures, we are referring to the results, tables and figures from Zitovsky et al. 2023. If any part of this repository is used or referenced, please cite the associated paper. The main purpose of these scripts is to make the results of our paper reproducible. However, the scripts for our Atari experiments will also enable future researchers to conduct OMS experiments on the same or different Atari environments more easily and with greater computational efficiency, which is of independent interest and utility. The code base and documentation are a work in progress and will be updated over the next few months. The corresponding author of this repository is Josh Zitovsky, who can be reached at joshz@live.unc.edu with any questions.
The scripts needed to reproduce our toy environment results can be found in the `toy_scripts` directory of this repository. The `runExperiment.R` script takes a single command-line argument representing the generative model parameter x (see Appendix section C.1 of our paper for more details). For example, running the command `Rscript runExperiment.R 0.5` runs the toy experiment discussed in section 6.1 and Appendix section C.1 of our paper with x=0.5 for the generative model. This includes simulating the training and validation datasets, estimating the Q-functions, evaluating their performance using online interactions, running SBV, running EMSBE and estimating the true MSBE. The script will create two files storing relevant information about the experiment. After running the experiment over the values of x considered in our paper, run the `plotMSBE.R` and `plotNoise.R` scripts to reproduce Figures 1 and 2 of our paper, and `plotMSBE_extra.R` to reproduce Figure G.1. The `RFunctions.R` file contains helper functions used by `runExperiment.R` to simulate the data, process the data and estimate the Q-functions, and may be helpful for future researchers who wish to conduct different OMS experiments on the same environment. Our documentation will be extended to include a thorough description of how these functions work in the near future. Please email the corresponding author if you are a researcher who needs to use these functions sooner.
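The sweep over x can be scripted. The sketch below is a hypothetical Python driver, not part of this repository: the values in `X_VALUES` are placeholders (the values used in our paper are not listed here), and it assumes every script, including the plotting scripts, is run from within `toy_scripts`.

```python
import subprocess

# Placeholder x values for illustration only; substitute the values used in the paper.
X_VALUES = [0.25, 0.5, 0.75]


def toy_commands(x_values):
    """Build one `Rscript runExperiment.R <x>` command per generative-model parameter."""
    return [["Rscript", "runExperiment.R", str(x)] for x in x_values]


def plot_commands():
    """Plotting scripts named in this README (Figures 1, 2 and G.1)."""
    return [["Rscript", name] for name in ("plotMSBE.R", "plotNoise.R", "plotMSBE_extra.R")]


def run_all(x_values, workdir="toy_scripts"):
    """Run every experiment, then the plots; check=True stops on the first failing command."""
    for cmd in toy_commands(x_values) + plot_commands():
        subprocess.run(cmd, cwd=workdir, check=True)
```

Calling `run_all(X_VALUES)` from the repository root runs the experiments sequentially before plotting. Since each `runExperiment.R` call is independent, the sweep could instead be split across separate jobs.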
The scripts needed to reproduce our mHealth environment results can be found in the `mHealth_scripts` directory of this repository. The `runExperiment.R` script takes a single command-line argument representing the seed to set when simulating the data. For example, running the command `Rscript runExperiment.R 42` runs the mHealth experiment discussed in section 6.2 and Appendix section C.3 of our paper with the seed set to 42 when simulating the dataset. This includes fitting the Q-functions, evaluating their policy values using online interactions, and running SBV, EMSBE and WIS. The script will create two files storing relevant information about the experiment. To reproduce Figure 3 of our paper, experiments must be run over seeds 42-51. Moreover, running `applyFQE1.R` and `applyFQE2.R` will apply FQE to the Q-functions estimated from the seed=42 dataset using the first and second training configurations in Table 1 of our paper, respectively (`Rscript runExperiment.R 42` must be run first). These scripts will then output csv tables containing, among other information, the Q-function names (`method`), the average reward of the policies (`return`) and the FQE estimates (`fqe`). The `RFunctions.R` file contains helper functions used by `runExperiment.R` to simulate and process the data, estimate the Q-functions and apply FQE, and may be helpful for future researchers who wish to conduct different OMS experiments on the same environment. Our documentation will be extended to include a thorough description of how these functions work in the near future. Please email the corresponding author if you are a researcher who needs to use these functions sooner.
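Once the FQE tables are written, a natural check is which Q-function FQE would select and how far its true return falls from the best one. The helper below is a minimal sketch that relies only on the column names described above; the csv file name is left to the caller, and the regret summary is our illustration rather than an output of the repository.

```python
import csv


def load_rows(path):
    """Read a csv table written by applyFQE1.R or applyFQE2.R (file name chosen by the caller)."""
    with open(path, newline="") as f:
        return list(csv.DictReader(f))


def select_by_fqe(rows, name_col="method", value_col="return", score_col="fqe"):
    """Pick the Q-function with the largest FQE estimate and compare it with the oracle choice.

    Returns (selected name, its true value, best true value, regret).
    """
    selected = max(rows, key=lambda r: float(r[score_col]))
    best = max(rows, key=lambda r: float(r[value_col]))
    chosen_value = float(selected[value_col])
    best_value = float(best[value_col])
    return selected[name_col], chosen_value, best_value, best_value - chosen_value
```

For the Bike tables, the same function applies with `name_col="parmVec", value_col="survVec", score_col="fqeVec"`.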
The scripts needed to reproduce the Bike environment results discussed in section 6.2 and Appendix section C.2 of our paper can be found in the `bike_scripts` directory of this repository and follow a similar interface: the `runExperiment.R` script takes a single command-line argument representing the seed to set when simulating the data, and to reproduce Figure 3 of our paper, experiments must be run over seeds 42-51. The csv files created by running `applyFQE1.R` and `applyFQE2.R` have different information and column names, including the Q-function names (`parmVec`), the average survival of the policies (`survVec`) and the FQE estimates (`fqeVec`), and these scripts apply the FQE configurations in the third and fourth rows of Table 1 of our paper.
Once experiments have been run over seeds 42-51 for both mHealth and Bike, we can then call `plotBars.R` to reproduce Figures 3 and G.2 of our paper, and `plotOracle.R` to reproduce Figure G.3.
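A hypothetical Python driver for the seed sweep is sketched below. It assumes it is launched from the repository root, so that `mHealth_scripts` and `bike_scripts` are subdirectories of the working directory. It does not call the plotting scripts, whose working directory is not specified here.

```python
import subprocess

SEEDS = range(42, 52)  # seeds 42-51 inclusive, as stated above
ENV_DIRS = ("mHealth_scripts", "bike_scripts")


def seed_jobs(env_dirs=ENV_DIRS, seeds=SEEDS):
    """List (working directory, command) pairs, one per environment and seed."""
    return [(d, ["Rscript", "runExperiment.R", str(s)]) for d in env_dirs for s in seeds]


def run_jobs(jobs):
    """Run each job inside its environment directory; check=True raises on failure."""
    for workdir, cmd in jobs:
        subprocess.run(cmd, cwd=workdir, check=True)
```

`run_jobs(seed_jobs())` runs all 20 experiments one after another. The jobs are independent, so they could equally be submitted in parallel.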
The scripts needed to reproduce our Atari environment results can be found in the `atari_scripts` directory of this repository. Running these experiments requires more steps and computation than those of the previous environments.
The `processData.sh` script takes two arguments (referenced as `arg1` and `arg2`) as inputs, where `arg1` is an Atari game environment (e.g. `Pong` or `SpaceInvaders`) and `arg2` is an integer between 1 and 5. `arg1` and `arg2` must be chosen such that the Google Cloud Storage path `gs://atari-replay-datasets/dqn/arg1/arg2/replay_logs` exists and contains no missing files (i.e. 51 million states, actions, rewards and terminal statuses sharded across 51 files should be present, among other necessary metadata). Here `gs://atari-replay-datasets` is the official bucket associated with the DQN-Replay Datasets (Agarwal, Schuurmans and Norouzi 2020). When we began our research, certain subdirectories such as `dqn/Asterix/1` and `dqn/Breakout/4` had fewer than 51 files for the states, actions, rewards and terminal statuses, and the script will crash if it is used with these values of `arg1` and `arg2`. However, modifying the script to handle an arbitrary number of files should not be too difficult if necessary. The script loads the requested data from `gs://atari-replay-datasets/` and performs preprocessing steps to put the data in a format more amenable to DQN, SBV and other Q-learning and OMS training algorithms. Specifically, the script:
- creates a new directory `[arg1][arg2]`. All future steps take place within this directory.
- loads all the data from `gs://atari-replay-datasets/dqn/arg1/arg2/replay_logs` and stores it under a `logs` directory.
- copies the scripts from the `atari_scripts` directory.
- uniformly subsamples 20% of the data for training and 5% for validation, yielding `train_size` and `val_size` transitions, respectively. Episodes used for training are different from the episodes used for validation.
- creates a `data` directory with subdirectories `train`, `train2` and `validate`. The `data/train/replay_logs` subdirectory contains the following files:
  - `obsers.npy`, a $\text{train\_size}\times 84\times 84$ array storing the `train_size` observed frames for training;
  - `actions.npy`, a length-$\text{train\_size}$ vector of sparse-encoded observed actions for training;
  - `rewards.npy`, a length-$\text{train\_size}$ vector of observed immediate rewards for training;
  - `terminals.npy`, a length-$\text{train\_size}$ vector of observed terminal statuses associated with the observations in `obsers.npy`;
  - `episodes.npy`, a length-$\text{train\_size}$ vector containing episode IDs for the training transitions corresponding to the four aforementioned files; and
  - `deletes.npy`, a preprocessing vector used by our scripts to filter out, or give weight zero to, invalid states (it indicates the time steps at which states are invalid, or at which states include a stack of observations corresponding to different episodes).
- creates a `data/validate/replay_logs` subdirectory containing files similar to those in `data/train/replay_logs`, except that they are used for validation.
- creates a `data/train2/replay_logs` subdirectory that shards the training data into multiple compressed files to be used by the `batch_rl` GitHub repository (Agarwal, Schuurmans and Norouzi 2020).
- creates a `shards` directory that shards the training data into multiple files to reduce memory consumption. This directory is used to run SBV in a more memory-efficient manner on three of the four games in our Atari experiments.
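The requirement that training and validation episodes never overlap can be met by splitting at the level of whole episodes. The sketch below illustrates one way to do this from an episode-ID vector like `episodes.npy`. It is not the code in `processData.sh`, and its greedy rule (shuffle the episodes, then fill the training split to about 20% of transitions before filling the validation split to about 5%) is a hypothetical choice.

```python
import numpy as np


def split_by_episode(episode_ids, train_frac=0.20, val_frac=0.05, seed=0):
    """Assign whole episodes to training or validation so the two sets never share an episode.

    Returns two boolean masks over transitions (training, validation).
    """
    episode_ids = np.asarray(episode_ids)
    rng = np.random.default_rng(seed)
    unique, counts = np.unique(episode_ids, return_counts=True)
    n_total = len(episode_ids)
    train_eps, val_eps = [], []
    n_train = n_val = 0
    for i in rng.permutation(len(unique)):  # visit episodes in random order
        if n_train < train_frac * n_total:
            train_eps.append(unique[i])
            n_train += counts[i]
        elif n_val < val_frac * n_total:
            val_eps.append(unique[i])
            n_val += counts[i]
        else:
            break  # both splits are full; remaining episodes are unused
    return np.isin(episode_ids, train_eps), np.isin(episode_ids, val_eps)
```

Because whole episodes are assigned, the realized fractions overshoot the targets by at most one episode's worth of transitions.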
Once `processData.sh` finishes writing the `data/train2/replay_logs` files, we can run the `runDqn.sh` script to run the simple DQN configuration discussed in section 6.3 and Appendix section D.1 of our paper for 2000 iterations. The script uses modifications of the `batch_rl` repository (Agarwal, Schuurmans and Norouzi 2020). Note that a single iteration in our paper corresponds to 40 iterations here. Relevant information will be saved in the `tmp/dqn` directory. Evaluation is done over a larger number of episodes every 40 iterations (for other iterations, evaluation is over a single episode only). The evaluated returns after each iteration are logged in the `tmp/dqn/logs` directory and can be accessed using the `dopamine.colab.utils.read_experiment` function, similar to logs saved by `batch_rl` and the `dopamine` library (Castro et al. 2018). The architecture of the Q-Network is saved as `tmp/dqn/testmodel` and can be accessed using the `keras.models.load_model` function. The weights of the trained Q-Network are saved every 40 iterations in `tmp/dqn/checkpoints` as `40.npy`, `80.npy`, etc. These weights can be loaded using the `numpy.load` function and assigned to an existing `keras` model object using its `set_weights` method. `runDqn_deep.sh` is similar, except that it runs the deep DQN configuration discussed in section 6.3 and Appendix section D.1 of our paper and stores relevant data in `tmp/dqn_deep`. We plan to extend our documentation to explain how to modify these scripts to change the Q-Network architecture, optimizer and other hyperparameters in the near future. Please email the corresponding author if you are a researcher who needs this information sooner.
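A minimal sketch of reloading a trained Q-Network from these files is below. It assumes the checkpoints were written with `numpy.save` as an object array of per-layer weight arrays (hence `allow_pickle=True`) and a TensorFlow 2 era Keras that can read the saved `testmodel`. The helper names are hypothetical.

```python
from pathlib import Path

import numpy as np


def checkpoint_path(run_dir, iteration):
    """Weights are saved every 40 iterations as <iteration>.npy under <run_dir>/checkpoints."""
    if iteration <= 0 or iteration % 40 != 0:
        raise ValueError("checkpoints exist only for positive multiples of 40")
    return Path(run_dir) / "checkpoints" / f"{iteration}.npy"


def load_checkpoint_weights(path):
    """Load a list of weight arrays; allow_pickle assumes an object array was saved."""
    return list(np.load(path, allow_pickle=True))


def load_q_network(run_dir, iteration):
    """Rebuild a trained Q-Network: saved architecture plus the weights of one checkpoint."""
    from tensorflow import keras  # imported lazily so the helpers above work without TensorFlow

    model = keras.models.load_model(str(Path(run_dir) / "testmodel"))
    model.set_weights(load_checkpoint_weights(checkpoint_path(run_dir, iteration)))
    return model
```

For example, `load_q_network("tmp/dqn_deep", 2000)` would return the final deep DQN Q-Network under these assumptions.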
Once our DQN agents are fully trained using `runDqn.sh` and `runDqn_deep.sh`, the `calcTargets.sh` script performs additional calculations and preprocessing necessary to use our EMSBE, SBV, WIS and FQE scripts. The script takes as argument (`arg1`) the DQN run (either `dqn` or `dqn_deep`) and creates `targetsT` and `targetsV` directories (unless they already exist). For iteration $m$, the script will create the following files in `targetsT` from the training set: `y_[arg1]_[m]`, a $\text{train\_size}\times 2$ matrix with the first column containing the backup targets $r_t+(1-d_{t+1})\gamma\max_{a'}Q_{m}(s_{t+1},a')$, where $Q_m$ is the Q-Network after iteration $m$ and $d_{t+1}$ is the terminal status of the next state; `q_vals_[arg1]_[m]`, a length-`train_size` vector containing action-value estimates; and `q_policies_[arg1]_[m]`, a length-`train_size` vector containing policy estimates. In `targetsV`, similar files will be created, except that they are created from the validation set instead. We can also calculate such files for a single iteration and a single data partition (training or validation) using the `calcTargets.py` script. For example, `python calcTargets.py train dqn 40` will create the files within `targetsT` associated with run `tmp/dqn` and iteration 40, while `python calcTargets.py validate dqn 40` will create similar files within `targetsV`. Finally, `calcTargets.py` could in principle be applied to Q-Networks other than those trained by `runDqn.sh` and `runDqn_deep.sh`: we only require that the `tmp/[run]` directory exists, contains a saved Keras model titled `testmodel` and contains a `checkpoints` subdirectory with the Q-Network weights saved as `.npy` files titled by iteration number.
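To make the definition of `y_[arg1]_[m]` concrete, the vectorized computation of the backup targets is sketched below. This is an illustration rather than the code in `calcTargets.py`: it assumes the terminal indicators have already been aligned so that entry $t$ holds $d_{t+1}$, and the default $\gamma=0.99$ is an assumption.

```python
import numpy as np


def backup_targets(rewards, next_terminals, next_q_values, gamma=0.99):
    """Compute r_t + (1 - d_{t+1}) * gamma * max_a' Q(s_{t+1}, a') for each transition.

    rewards[t]        : immediate reward r_t
    next_terminals[t] : terminal status d_{t+1} of the next state, already aligned to t
    next_q_values[t]  : vector of Q(s_{t+1}, a') over all actions a'
    """
    rewards = np.asarray(rewards, dtype=float)
    next_terminals = np.asarray(next_terminals, dtype=float)
    next_max = np.asarray(next_q_values, dtype=float).max(axis=1)  # greedy bootstrap value
    return rewards + (1.0 - next_terminals) * gamma * next_max
```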
Once `calcTargets.sh` has been run, we can use our scripts to implement the OMS algorithms SBV, WIS, FQE and EMSBE as in the Atari experiments of our paper (section 6.3). EMSBE can be implemented by running `genTargetDict.py`: this script calculates the validation EMSBE for the Q-Networks saved in `tmp/dqn`, `tmp/dqn_deep` or both (if they exist), for every 40 iterations from 40 to 2000, and writes the output to a `target_dict.csv` file, as implemented in the Atari experiments of our paper. It is easy enough to modify the script to calculate the EMSBE for different Q-functions or for different runs: all we require is that the files `y_[run]_[iter]` and `q_vals_[run]_[iter].npy` exist for the given run `run` and iterations `iter`, as calculated by `calcTargets.py`.
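EMSBE itself is a simple average once the targets and Q-values exist. The sketch below computes it from arrays shaped like `y_[run]_[iter]` and `q_vals_[run]_[iter].npy`. It is not `genTargetDict.py`: the `.npy` suffix on the `y_` file and the optional `mask` argument are assumptions, and any filtering the real script applies (for example via `deletes.npy`) is not reproduced.

```python
from pathlib import Path

import numpy as np


def emsbe(y, q_vals, mask=None):
    """Empirical mean squared Bellman error: mean of (backup target - Q(s_t, a_t))**2.

    y      : array whose first column holds the backup targets
    q_vals : matching vector of Q(s_t, a_t) estimates
    mask   : optional boolean vector of transitions to keep (hypothetical)
    """
    y = np.asarray(y, dtype=float)
    targets = y[:, 0] if y.ndim == 2 else y
    sq_err = (targets - np.asarray(q_vals, dtype=float)) ** 2
    if mask is not None:
        sq_err = sq_err[np.asarray(mask, dtype=bool)]
    return float(sq_err.mean())


def emsbe_from_files(run, iteration, folder="targetsV"):
    """Load validation targets and Q-values for one run and iteration, then score them."""
    folder = Path(folder)
    y = np.load(folder / f"y_{run}_{iteration}.npy")
    q = np.load(folder / f"q_vals_{run}_{iteration}.npy")
    return emsbe(y, q)
```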
WIS can be implemented by running the script `trainPropensity.py` within the `Asterix`, `Breakout`, `Pong` or `Seaquest` subdirectories. This script trains a propensity network, i.e. an estimate of the behavior policy, for the corresponding game. The training configuration differs between games, hence the different files: the `Pong` file is the most different and uses a much simpler architecture, while the remaining files are more similar to each other, except for minor hyperparameter changes such as the learning rate. The `trainPropensity.py` script was designed to maximize computational and memory efficiency per training epoch subject to loading the entire dataset into RAM at once. Greater memory efficiency can be obtained by loading partial data periodically from `data/train` instead of loading all the data into memory at once. This is how we implemented SBV, and modifying our `trainPropensity.py` file to do this should not be too difficult. The script saves the files `propensity.csv` and `propensity.hdf5` to the `models` directory, containing the training progress and the saved weights of the best iteration based on validation log-likelihood, respectively. It also saves a length-`val_size` vector `propensities.npy` to `targetsV` containing the estimated propensities for the validation set. Once the propensities are saved, we can then use `genWISDict.py` to calculate WIS for the Q-Networks saved in `tmp/dqn`, `tmp/dqn_deep` or both (if they exist), for every 40 iterations from 40 to 2000, and write the resulting output to a `wis_dict.csv` file. It is easy enough to modify the script to calculate WIS for different Q-functions or for different runs and iterations: all we require is that the files `q_policies_[run]_[iter].npy` exist for the given run `run` and iterations `iter`, as calculated by `calcTargets.py`.
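For reference, a standard trajectory-wise WIS estimator for a deterministic (greedy) target policy is sketched below, using arrays analogous to `episodes.npy`, `actions.npy`, `rewards.npy`, `q_policies_[run]_[iter].npy` and `propensities.npy`. It is an illustration under those assumptions, not `genWISDict.py`, which may differ, for example in discounting or in how the weights are normalized.

```python
import numpy as np


def wis_estimate(episodes, actions, policy_actions, propensities, rewards, gamma=0.99):
    """Trajectory-wise weighted importance sampling for a deterministic target policy.

    Each episode's weight is prod_t 1{a_t == pi(s_t)} / p_hat(a_t | s_t), where p_hat
    is the estimated behavior propensity of the observed action. The estimate is the
    weighted average of discounted episode returns; transitions must be in time order.
    """
    episodes = np.asarray(episodes)
    actions = np.asarray(actions)
    policy_actions = np.asarray(policy_actions)
    propensities = np.asarray(propensities, dtype=float)
    rewards = np.asarray(rewards, dtype=float)
    weights, returns = [], []
    for ep in np.unique(episodes):
        idx = np.flatnonzero(episodes == ep)
        match = (actions[idx] == policy_actions[idx]).astype(float)
        weights.append(np.prod(match / propensities[idx]))
        returns.append(np.sum(rewards[idx] * gamma ** np.arange(len(idx))))
    weights = np.asarray(weights)
    total = weights.sum()
    if total == 0.0:
        return float("nan")  # no episode is fully consistent with the target policy
    return float(np.dot(weights, returns) / total)
```

For Atari-length episodes the product of inverse propensities overflows or underflows quickly, so a practical version would accumulate the weights in log space.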
FQE can be implemented by running the script `fitQEval.py`. This script runs FQE using a training configuration similar to that of `runDqn.sh`, i.e. the DQN (Adam) configuration of Agarwal, Schuurmans and Norouzi 2020 or the shallow DQN configuration of our paper, minus the obvious modifications needed to perform policy evaluation instead of optimal Q-function estimation. The script takes as command-line arguments `run`, `iter` and `prev_iter`: `run` and `iter` should be chosen such that the file `targetsT/q_policies_[run]_[iter].npy` exists, as calculated by `calcTargets.py`. The script then runs FQE for 100 epochs, saving the training progress as `models/q_model_[run]_[iter].csv`, the trained weights as `models/q_model_[run]_[iter].hdf5` and the estimated expected returns (based on the trained FQE model) as `targetsV/value_[run]_[iter].npy`. We set `prev_iter` to `iter`-40, assuming the FQE model for `iter`-40 had already finished training. If `models/q_model_[run]_[prev_iter].hdf5` exists, the script uses these weights as starting values and halves the number of training epochs from 100 to 50. Otherwise, it uses random starting weights and runs for 100 epochs. The `fitQEval.py` script was designed to maximize computational and memory efficiency subject to loading the entire dataset into RAM at once. Greater memory efficiency can be obtained by loading partial data periodically from `data/train` instead of loading all the data into memory at once. This is how we implemented SBV, and modifying our `fitQEval.py` file to do this should not be too difficult. Once the desired estimated expected returns are saved, we can then use `genFQEDict.py` to calculate FQE for the Q-Networks saved in `tmp/dqn`, `tmp/dqn_deep` or both (if they exist), for every 40 iterations from 40 to 2000, and write the resulting output to a `fqe_dict.csv` file. It is easy enough to modify the script to calculate FQE for different Q-functions or for different runs and iterations: all we require is that the files `y_[run]_[iter]` and `q_policies_[run]_[iter]` exist for the given run `run` and iterations `iter`, as calculated by `calcTargets.sh`.
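The key change FQE makes to the DQN target is replacing $\max_{a'}$ with the action chosen by the policy being evaluated. The sketch below shows this with a lookup table standing in for the neural network, so that each regression step reduces to averaging targets per state-action pair. It illustrates the technique and is not `fitQEval.py`; all names are hypothetical.

```python
import numpy as np


def tabular_fqe(states, actions, rewards, next_states, dones, policy,
                n_states, n_actions, gamma=0.99, n_iters=100):
    """Fitted Q evaluation with a lookup table in place of a neural network.

    policy[s] is the action the evaluated (greedy) policy takes in state s, and
    dones[t] is 1.0 if s_{t+1} is terminal, so no value is bootstrapped from it.
    Each iteration regresses Q(s, a) onto r + gamma * (1 - d) * Q(s', policy(s'))
    by averaging targets per (s, a); unvisited pairs stay at zero.
    """
    states, actions, next_states = map(np.asarray, (states, actions, next_states))
    rewards = np.asarray(rewards, dtype=float)
    dones = np.asarray(dones, dtype=float)
    policy = np.asarray(policy)
    counts = np.zeros((n_states, n_actions))
    np.add.at(counts, (states, actions), 1.0)
    seen = counts > 0
    q = np.zeros((n_states, n_actions))
    for _ in range(n_iters):
        targets = rewards + gamma * (1.0 - dones) * q[next_states, policy[next_states]]
        sums = np.zeros((n_states, n_actions))
        np.add.at(sums, (states, actions), targets)
        q[seen] = sums[seen] / counts[seen]
    return q
```

With a self-loop of reward 1 and $\gamma=0.5$, the estimate converges to $1/(1-0.5)=2$, the discounted return of the evaluated policy.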
Finally, SBV can be implemented by running the script `runSBV.py` within the `Asterix`, `Breakout`, `Pong` or `Seaquest` subdirectories. These scripts were used to train Bellman networks for the Asterix, Breakout, Pong and Seaquest experiments, respectively, as discussed in section 6.3 and Appendix section D.2 of our paper. The Bellman network training configuration differs between these environments, hence the different files. The `Pong` file is the most different because it uses a much simpler architecture. Moreover, for `Pong`, `Asterix` and `Seaquest`, we avoided loading all the data into memory at once, instead loading only a few files from `shards/train` at a time to improve memory efficiency. For `Breakout`, we found that this strategy led to a moderate increase in validation error (around 5%) and thus loaded the full data into memory instead. However, a similar memory-efficient implementation for `Breakout` can be executed with the `runSBV_memory.py` script. The script takes as arguments the run `run` and an arbitrary list of iterations `iter1 iter2 ...`. For example, running `python runSBV.py dqn_deep 40 80 120` will train three Bellman networks based on the Q-Networks of the `dqn_deep` run after iterations 40, 80 and 120. Moreover, for `Pong`, `Seaquest` and `Breakout`, the Bellman network for `dqn_deep` iteration 40 will be used as the starting value to estimate the Bellman network for `dqn_deep` iteration 80, and similarly for `dqn_deep` iteration 120, to improve computational efficiency. To avoid overfitting, we did not reuse starting weights from a previous iteration more than twice in a row for Breakout and Seaquest, or more than four times in a row for Pong. For instance, we would run `python runSBV.py dqn_deep 40 80 120` and then `python runSBV.py dqn_deep 160 200 240` to train Bellman networks for iterations 40-240 for Seaquest and Breakout. For `Asterix`, using any such warm starts led to a moderate increase in validation error, and thus running `python runSBV.py dqn_deep 40 80 120 160 200 240` using the `Asterix` script will train Bellman networks for iterations 40-240 while using random starting weights for all six Bellman networks.
When running `python runSBV.py run iter1 iter2 ...`, the files `models/[run]_[iter1].hdf5`, `models/[run]_[iter2].hdf5`, etc. will be written to save the trained weights after the best iteration, and `models/[run]_[iter1].csv`, `models/[run]_[iter2].csv`, etc. will save information about the training progress. Finally, the files `targetsV/backups_[run]_[iter1].npy`, `targetsV/backups_[run]_[iter2].npy`, etc. will be saved, storing the estimated Bellman backups for the validation set. Once the desired estimated backups are saved, we can then use `genSBVDict.py` to calculate SBV for the Q-Networks saved in `tmp/dqn`, `tmp/dqn_deep` or both (if they exist), and write the resulting output to a `sbv_dict.csv` file. It is easy enough to modify the `runSBV.py` and `genSBVDict.py` scripts to run SBV on different Q-functions or for different runs and iterations: all we require is that the files `y_[run]_[iter]` and `q_vals_[run]_[iter]` exist for the given run `run` and iterations `iter`, as calculated by `calcTargets.sh`.
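A minimal sketch of the SBV logic is given below, following our reading of the method: a model of the Bellman backup is fit by regressing the targets `y` on state-action pairs, and each Q-function is scored by the mean squared difference between the estimated backups and its Q-values on the validation set. A per-pair average stands in for the Bellman network, and `genSBVDict.py` may compute a related but different quantity. The function names are hypothetical.

```python
import numpy as np


def fit_tabular_backup(states, actions, targets, n_states, n_actions):
    """Estimate the Bellman backup (BQ)(s, a) as the mean backup target per (s, a) pair.

    This stands in for a Bellman network: both regress y_t on (s_t, a_t) under
    squared loss. Unvisited pairs are left as NaN.
    """
    states = np.asarray(states)
    actions = np.asarray(actions)
    targets = np.asarray(targets, dtype=float)
    sums = np.zeros((n_states, n_actions))
    counts = np.zeros((n_states, n_actions))
    np.add.at(sums, (states, actions), targets)
    np.add.at(counts, (states, actions), 1.0)
    with np.errstate(invalid="ignore", divide="ignore"):
        return sums / counts


def sbv_score(backups, q_vals):
    """Validation score: mean of (estimated backup - Q(s_t, a_t))**2; lower is better."""
    diff = np.asarray(backups, dtype=float) - np.asarray(q_vals, dtype=float)
    return float(np.mean(diff ** 2))
```

The backup model is fit on training transitions and evaluated on validation transitions, mirroring the split between `targetsT` and `targetsV`.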
To reproduce our tables and figures, we first use `processData.sh` to process data for the following subdirectories of the bucket `gs://atari-replay-datasets/dqn`: `Pong/1`, `Pong/2`, `Pong/4`, `Seaquest/1`, `Seaquest/2`, `Seaquest/4`, `Breakout/1`, `Breakout/2`, `Breakout/3`, `Asterix/2`, `Asterix/3` and `Asterix/5`. We did not use datasets 1, 2 and 4 for Breakout and Asterix because `Asterix/1`, `Breakout/4` and `Breakout/5` were all missing files when we first started our research. We then run `sh runDqn.sh` on Pong and both `sh runDqn.sh` and `sh runDqn_deep.sh` on the remaining three games. We then run `sh calcTargets.sh dqn` on Pong and both `sh calcTargets.sh dqn` and `sh calcTargets.sh dqn_deep` on the remaining three games. We then use `trainPropensity.py` and `runSBV.py` to train the necessary propensity and Bellman networks. We then run `genDicts.sh` to create the `target_dict.csv`, `wis_dict.csv` and `sbv_dict.csv` files. Finally, the associated Atari directories (e.g. `Breakout1`, `Seaquest4`, `Asterix5`) containing the files `target_dict.csv`, `wis_dict.csv` and `sbv_dict.csv` should all be moved into the same parent directory. From this parent directory, we can then run `plotGames2.R` and `plotGames.R` to generate Figures 4 and G.5 of our paper, respectively. We can also run `atariOracle.sh` to generate Figure G.4. Note that due to the stochasticity of deep learning training, the resulting figures and numbers may differ somewhat from those in our paper, but the general conclusion that SBV outperforms EMSBE, WIS and FQE should still be apparent.
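The per-dataset order of the first steps above can be written down as data. The sketch below only builds command lists and runs nothing. The working directories are assumptions: `processData.sh` is assumed to be invoked where the Atari scripts live, and the remaining commands inside the `[arg1][arg2]` directory it creates. The propensity, Bellman network and `genDicts.sh` steps are omitted because their per-game scripts differ.

```python
DATASETS = {
    "Pong": (1, 2, 4),
    "Seaquest": (1, 2, 4),
    "Breakout": (1, 2, 3),
    "Asterix": (2, 3, 5),
}
DEEP_GAMES = {"Seaquest", "Breakout", "Asterix"}  # games that also get the deep DQN run
TRAIN_SCRIPT = {"dqn": "runDqn.sh", "dqn_deep": "runDqn_deep.sh"}


def pipeline(game, index):
    """Return the output directory name and the ordered commands for one dataset."""
    runs = ["dqn", "dqn_deep"] if game in DEEP_GAMES else ["dqn"]
    cmds = [["sh", "processData.sh", game, str(index)]]
    cmds += [["sh", TRAIN_SCRIPT[r]] for r in runs]
    cmds += [["sh", "calcTargets.sh", r] for r in runs]
    return f"{game}{index}", cmds


def all_pipelines():
    """One pipeline per dataset listed above (12 in total)."""
    return [pipeline(g, i) for g, idxs in DATASETS.items() for i in idxs]
```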
To make reproducibility easier, we also store the Q-functions analyzed in section 6.3 of our paper, as well as the `wis_dict.csv`, `sbv_dict.csv` and `target_dict.csv` files used to generate Table 2 and Figures 4, G.4 and G.5. These objects can be found in the `atari_objects` directory of this repository. The FQE results for Seaquest discussed in section 6.3 of our paper can also be found in the `fqe_dict.csv` file located in `atari_objects/Seaquest1`.
[1] Joshua P. Zitovsky, Daniel de Marchi, Rishabh Agarwal and Michael R. Kosorok. "Revisiting Bellman Errors for Offline Model Selection". arXiv, abs/2302.00141, 2023.
[2] Rishabh Agarwal, Dale Schuurmans and Mohammad Norouzi. "An Optimistic Perspective on Offline Reinforcement Learning". ICML, 2020.
[3] Pablo S. Castro, Subhodeep Moitra, Carles Gelada, Saurabh Kumar and Marc G. Bellemare. "Dopamine: A Research Framework for Deep Reinforcement Learning". arXiv, abs/1812.06110, 2018.