
SBV

This is the code repository associated with Zitovsky et al. 2023 (see the References section at the end of this document). For the remainder of this document, "our" paper refers to Zitovsky et al. 2023, and "our" results, tables and figures refer to the results, tables and figures of that paper. If any part of this repository is used or referenced, please cite the associated paper. The main purpose of these scripts is to make the results of our paper reproducible. However, the scripts for our Atari experiments also make it easier and more computationally efficient for future researchers to conduct OMS (offline model selection) experiments on the same or different Atari environments, which is of independent interest. The code base and documentation are a work in progress and will be updated over the next few months. The corresponding author of this repository is Josh Zitovsky, who can be reached at joshz@live.unc.edu with any questions.

Toy Experiments

The scripts needed to reproduce our toy environment results can be found in the toy_scripts directory of this repository. The runExperiment.R script takes a single command-line argument, the generative model parameter x (see Appendix section C.1 of our paper for details). For example, the command Rscript runExperiment.R 0.5 runs the toy experiment discussed in section 6.1 and Appendix section C.1 of our paper with x=0.5: it simulates the training and validation datasets, estimates the Q-functions, evaluates their performance using online interactions, runs SBV and EMSBE, estimates the true MSBE, and writes two files storing relevant information about the experiment. After running the experiment for each $x\in\{0.5,0.55,0.6,0.65,0.7,0.75\}$, the plotMSBE.R and plotNoise.R scripts reproduce Figures 1 and 2 of our paper, and plotMSBE_extra.R reproduces Figure G.1. The RFunctions.R file contains helper functions used by runExperiment.R to simulate the data, process the data and estimate the Q-functions, and may be helpful for future researchers who wish to conduct different OMS experiments on the same environment. Our documentation will be extended to include a thorough description of how these functions work in the near future. Please email the corresponding author if you are a researcher who needs to use these functions sooner.
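As a convenience, the sweep over x can be scripted. The following is a minimal Python sketch, not part of the repository: the helper names `toy_commands` and `run_commands` are hypothetical, and the sketch assumes `Rscript` is on the PATH, that it is launched from inside toy_scripts, and that the plotting scripts need no arguments. With `dry_run=True` it only prints the commands.

```python
import subprocess

# Generative-model parameters x used for the toy sweep (Figures 1, 2 and G.1).
X_GRID = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75]


def toy_commands(xs=X_GRID):
    """One `Rscript runExperiment.R x` call per x, then the plotting scripts."""
    cmds = [["Rscript", "runExperiment.R", str(x)] for x in xs]
    # Plotting only makes sense once every experiment has finished.
    for plot_script in ("plotMSBE.R", "plotNoise.R", "plotMSBE_extra.R"):
        cmds.append(["Rscript", plot_script])
    return cmds


def run_commands(cmds, dry_run=True):
    """Print each command; unless dry_run, execute it and stop on the first failure."""
    for cmd in cmds:
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_commands(toy_commands(), dry_run=True)
```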

mHealth and Bike Experiments

The scripts needed to reproduce our mHealth environment results can be found in the mHealth_scripts directory of this repository. The runExperiment.R script takes a single command-line argument, the seed to set when simulating the data. For example, the command Rscript runExperiment.R 42 runs the mHealth experiment discussed in section 6.2 and Appendix section C.3 of our paper with the seed set to 42 when simulating the dataset: it fits the Q-functions, evaluates their policy values using online interaction, runs SBV, EMSBE and WIS, and writes two files storing relevant information about the experiment. To reproduce Figure 3 of our paper, the experiment must be run for each seed from 42 to 51. Moreover, applyFQE1.R and applyFQE2.R apply FQE to the Q-functions estimated from the seed=42 dataset using the first and second training configurations in Table 1 of our paper, respectively (Rscript runExperiment.R 42 must be run first). These scripts output csv tables containing, among other information, the Q-function names (method), the average reward of the policies (return) and the FQE estimates (fqe). The RFunctions.R file contains helper functions used by runExperiment.R to simulate and process the data, estimate the Q-functions and apply FQE, and may be helpful for future researchers who wish to conduct different OMS experiments on the same environment. Our documentation will be extended to include a thorough description of how these functions work in the near future. Please email the corresponding author if you are a researcher who needs to use these functions sooner.
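One way to inspect these csv tables is to check which Q-function FQE would select and how far its true average reward falls from the best one. The sketch below is hypothetical (the function names are not part of the repository) and assumes only the method, return and fqe columns described above. It uses the standard library's csv module, so any extra columns (such as R row names) are simply ignored. The column names are parameters, so the same function also works for tables whose columns are named differently.

```python
import csv


def fqe_selection(rows, name_col="method", value_col="return", score_col="fqe"):
    """Pick the row with the largest FQE estimate.

    Returns (selected name, its true value, best true value, regret).
    """
    rows = list(rows)
    selected = max(rows, key=lambda r: float(r[score_col]))
    best_value = max(float(r[value_col]) for r in rows)
    chosen_value = float(selected[value_col])
    return selected[name_col], chosen_value, best_value, best_value - chosen_value


def fqe_selection_from_csv(path, **columns):
    """Same as fqe_selection, reading the rows from a csv file on disk."""
    with open(path, newline="") as f:
        return fqe_selection(csv.DictReader(f), **columns)
```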

The scripts needed to reproduce the Bike environment results discussed in section 6.2 and Appendix section C.2 of our paper can be found in the bike_scripts directory of this repository and follow a similar interface: the runExperiment.R script takes a single command-line argument, the seed to set when simulating the data, and reproducing Figure 3 of our paper requires running experiments for seeds 42-51. The applyFQE1.R and applyFQE2.R scripts in this directory apply the FQE configurations in the third and fourth rows of Table 1 of our paper, and the csv files they create have different contents and column names from their mHealth counterparts, including the Q-function names (parmVec), the average survival of the policies (survVec) and the FQE estimates (fqeVec).

Once experiments have been run for seeds 42-51 for both mHealth and Bike, plotBars.R reproduces Figures 3 and G.2 of our paper, and plotOracle.R reproduces Figure G.3.
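Because every seed is an independent run, the 20 experiments above can be launched in a loop (or in parallel on a cluster). A minimal Python sketch follows. `seed_commands` and `run_all` are hypothetical helpers, and the sketch assumes it is started from the repository root with `Rscript` on the PATH. It does not call plotBars.R or plotOracle.R, since the directory those scripts expect to be run from is not specified here. Remember that applyFQE1.R and applyFQE2.R need the seed-42 run to exist first.

```python
import subprocess

SEEDS = range(42, 52)  # seeds 42-51 inclusive
EXPERIMENT_DIRS = ("mHealth_scripts", "bike_scripts")


def seed_commands(dirs=EXPERIMENT_DIRS, seeds=SEEDS):
    """(working directory, argv) pairs: every seed in every environment."""
    return [(d, ["Rscript", "runExperiment.R", str(s)]) for d in dirs for s in seeds]


def run_all(jobs, dry_run=True):
    """Print each job; unless dry_run, run it in its directory and stop on failure."""
    for cwd, cmd in jobs:
        print(f"[{cwd}] {' '.join(cmd)}")
        if not dry_run:
            subprocess.run(cmd, cwd=cwd, check=True)


if __name__ == "__main__":
    run_all(seed_commands(), dry_run=True)
```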

Atari Experiments

The scripts needed to reproduce our Atari environment results can be found in the atari_scripts directory of this repository. Running these experiments requires more steps and computation than those of previous environments.

Preprocessing

The processData.sh script takes two arguments (referred to as arg1 and arg2), where arg1 is an Atari game environment (e.g. Pong or SpaceInvaders) and arg2 is an integer between 1 and 5. arg1 and arg2 must be chosen such that the Google Cloud bucket gs://atari-replay-datasets/dqn/arg1/arg2/replay_logs exists and contains no missing files (i.e. 51 million states, actions, rewards and terminal statuses sharded across 51 files should be present, along with other necessary metadata). Here gs://atari-replay-datasets is the official bucket associated with the DQN-Replay Datasets (Agarwal, Schuurmans and Norouzi 2020). When we began our research, certain subdirectories such as dqn/Asterix/1 and dqn/Breakout/4 had fewer than 51 files for the states, actions, rewards and terminal statuses, and our script will crash if it is run with these values of arg1 and arg2. However, modifying the script to handle an arbitrary number of files should not be too difficult if necessary. The script loads the requested data from gs://atari-replay-datasets/ and preprocesses it into a format more amenable to DQN, SBV and other Q-learning and OMS training algorithms. Specifically, the script:

  1. Creates a new directory [arg1][arg2]. All subsequent steps take place within this directory.
  2. Loads all the data from gs://atari-replay-datasets/dqn/arg1/arg2/replay_logs and stores it under a logs directory.
  3. Copies the scripts from the atari_scripts directory.
  4. Uniformly subsamples 20% of the data for training and 5% for validation, yielding train_size and val_size transitions, respectively. Episodes used for training are distinct from the episodes used for validation.
  5. Creates a data directory with subdirectories train, train2 and validate. The train/replay_logs subdirectory contains the following files: obsers.npy, a $\text{train\_size}\times 84\times 84$ array storing the train_size observed frames for training; actions.npy, a length-$\text{train\_size}$ vector of sparse-encoded observed actions; rewards.npy, a length-$\text{train\_size}$ vector of observed immediate rewards; terminals.npy, a length-$\text{train\_size}$ vector of observed terminal statuses associated with the observations in obsers.npy; episodes.npy, a length-$\text{train\_size}$ vector of episode IDs for the transitions in the four aforementioned files; and deletes.npy, a preprocessing vector used to filter out or give zero weight to invalid states in our scripts (it indicates the time steps at which states are invalid or include a stack of observations from different episodes).
  6. Creates a data/validate/replay_logs subdirectory containing files similar to those in data/train/replay_logs, except that they are used for validation.
  7. Creates a data/train2/replay_logs subdirectory that shards the training data into multiple compressed files to be used by the batch_rl GitHub repository (Agarwal, Schuurmans and Norouzi 2020).
  8. Creates a shards directory that shards the training data into multiple files to reduce memory consumption. This directory is used to run SBV in a more memory-efficient manner on three of the four games in our Atari experiments.
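Two of the less obvious preprocessing details, keeping training and validation episodes disjoint and flagging states whose frame stack is incomplete or crosses an episode boundary, can be sketched in NumPy as follows. This is an illustration under stated assumptions, not the code in processData.sh: it assumes a stack of 4 frames per state (the standard DQN setting), samples whole episodes to reach the 20%/5% targets, and returns boolean masks, whereas the on-disk encoding of deletes.npy may differ. The function names are hypothetical.

```python
import numpy as np

STACK = 4  # frames per state, the usual DQN setting (an assumption here)


def split_by_episode(episodes, train_frac=0.20, val_frac=0.05, seed=0):
    """Boolean masks (train, val) selecting whole episodes at random until
    each split holds roughly the requested fraction of transitions."""
    episodes = np.asarray(episodes)
    ids, counts = np.unique(episodes, return_counts=True)
    rng = np.random.default_rng(seed)
    n = len(episodes)
    train_ids, val_ids = [], []
    n_train = n_val = 0
    for i in rng.permutation(len(ids)):
        if n_train < train_frac * n:
            train_ids.append(ids[i])
            n_train += counts[i]
        elif n_val < val_frac * n:
            val_ids.append(ids[i])
            n_val += counts[i]
        else:
            break
    return np.isin(episodes, train_ids), np.isin(episodes, val_ids)


def invalid_states(episodes, stack=STACK):
    """True at time steps whose frame stack is incomplete or spans two episodes."""
    episodes = np.asarray(episodes)
    invalid = np.zeros(len(episodes), dtype=bool)
    invalid[: stack - 1] = True  # not enough history for a full stack
    for k in range(1, stack):
        # frame t-k belongs to a different episode than frame t
        invalid[k:] |= episodes[k:] != episodes[:-k]
    return invalid
```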

DQN and More Preprocessing

Once processData.sh finishes writing the data/train2/replay_logs files, the runDqn.sh script runs the simple DQN configuration discussed in section 6.3 and Appendix section D.1 of our paper for 2000 iterations. The script uses modified code from the batch_rl repository (Agarwal, Schuurmans and Norouzi 2020). Note that a single iteration in our paper corresponds to 40 iterations here. Relevant information is saved in the tmp/dqn directory. Every 40 iterations, evaluation is done over a larger number of episodes (for other iterations, evaluation uses only a single episode). The evaluated returns after each iteration are logged in the tmp/dqn/logs directory and can be accessed using the dopamine.colab.utils.read_experiment function, as with logs saved by batch_rl and the Dopamine library (Castro et al. 2018). The architecture of the Q-Network is saved as tmp/dqn/testmodel and can be loaded using the keras.models.load_model function. The weights of the trained Q-Network are saved every 40 iterations in tmp/dqn/checkpoints as 40.npy, 80.npy, etc. These weights can be loaded using the numpy.load function and assigned to an existing Keras model object using its set_weights method. runDqn_deep.sh is similar, except that it runs the deep DQN configuration discussed in section 6.3 and Appendix section D.1 of our paper and stores relevant data in tmp/dqn_deep. We plan to extend our documentation in the near future to explain how to modify these scripts to change the Q-Network architecture, optimizer and other hyperparameters. Please email the corresponding author if you are a researcher who needs this information sooner.
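The checkpoint layout above suggests the following loading pattern. It is a sketch that assumes each .npy checkpoint holds the per-layer weight list returned by a Keras model's `get_weights()`, stored as a NumPy object array (hence `allow_pickle=True`); if the files are stored differently, `set_checkpoint_weights` must be adapted. The helper names are hypothetical, and TensorFlow is imported only inside `load_q_network`, so the bookkeeping helpers work without it.

```python
import os

import numpy as np

ITERS_PER_PAPER_ITERATION = 40  # one paper iteration = 40 runDqn.sh iterations


def checkpoint_iterations(last=2000, step=ITERS_PER_PAPER_ITERATION):
    """Iterations with saved weights: 40, 80, ..., 2000."""
    return list(range(step, last + 1, step))


def to_paper_iteration(m):
    """Convert a checkpoint iteration m into the paper's iteration count."""
    return m // ITERS_PER_PAPER_ITERATION


def set_checkpoint_weights(model, source):
    """Load a saved list of weight arrays (path or file object) into `model`."""
    weights = np.load(source, allow_pickle=True)  # object array of layer weights
    model.set_weights(list(weights))
    return model


def load_q_network(run="dqn", m=2000, root="tmp"):
    """Rebuild the Q-network of tmp/<run> and load the weights saved at iteration m."""
    from tensorflow import keras  # only needed when loading a real model

    model = keras.models.load_model(os.path.join(root, run, "testmodel"))
    path = os.path.join(root, run, "checkpoints", f"{m}.npy")
    return set_checkpoint_weights(model, path)
```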

Once our DQN agents are fully trained using runDqn.sh and rundqn_deep.sh, the calcTargets.sh script performs the additional calculations and preprocessing necessary to use our EMSBE, SBV, WIS and FQE scripts. The script takes as argument (arg1) the DQN run (either dqn or dqn_deep) and creates targetsT and targetsV directories (unless they already exist). For each iteration $m\in\{40,80,\ldots,2000\}$, the script creates the following files within targetsT from the training set: y_[arg1]_[m], a $\text{train\_size}\times 2$ matrix whose first column contains the backup targets $r_t+(1-d_{t+1})\gamma\max_{a'}Q_{m}(s_{t+1},a')$, where $d_t$ is the terminal status at time step $t$ and $Q_m$ is the Q-function after iteration $m$, and whose second column contains training weights $w_t$, which are zero whenever $s_t$ is invalid or contains observations/frames from different episodes; q_vals_[arg1]_[m], a length-train_size vector containing the action-value estimates $Q_m(s_t,a_t)$; and q_policies_[arg1]_[m], a length-train_size vector containing the policy estimates $\pi_m(s_t)=\text{argmax}_aQ_m(s_t,a)$. Similar files are created within targetsV, except that they are computed from the validation set. The calcTargets.py script computes these files for a single iteration and a single data partition (training or validation). For example, python calcTargets.py train dqn 40 creates the files within targetsT associated with run tmp/dqn and iteration 40, while python calcTargets.py validate dqn 40 creates the corresponding files within targetsV. Finally, calcTargets.py could in principle be applied to Q-Networks other than those trained by runDqn.sh and rundqn_deep.sh: it only requires that the tmp/[run] directory exists, contains a saved Keras model named testmodel and contains a checkpoints subdirectory with the Q-Network weights saved as .npy files named by iteration number.
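For a single batch of transitions, the quantities stored by calcTargets.sh reduce to a few array operations once the Q-network outputs are available. The sketch below is illustrative and is not calcTargets.py itself: it assumes the caller has already evaluated $Q_m$ on $s_t$ and $s_{t+1}$ (arrays of shape n by the number of actions), that `valid` is 1 for usable states and 0 otherwise, and that $\gamma=0.99$. The function names are hypothetical.

```python
import numpy as np

GAMMA = 0.99  # discount factor; an assumed value


def backup_targets(rewards, next_terminals, next_q, valid, gamma=GAMMA):
    """(n, 2) array: column 0 holds r_t + (1 - d_{t+1}) * gamma * max_a' Q_m(s_{t+1}, a'),
    column 1 holds the weights w_t (0 for invalid or cross-episode states)."""
    rewards = np.asarray(rewards, dtype=np.float64)
    next_terminals = np.asarray(next_terminals, dtype=np.float64)
    y = rewards + (1.0 - next_terminals) * gamma * np.max(next_q, axis=1)
    w = np.asarray(valid, dtype=np.float64)
    return np.column_stack([y, w])


def q_values_and_policy(q, actions):
    """Q_m(s_t, a_t) for the logged actions and the greedy policy argmax_a Q_m(s_t, a)."""
    q = np.asarray(q)
    actions = np.asarray(actions)
    q_vals = q[np.arange(len(actions)), actions]
    return q_vals, np.argmax(q, axis=1)
```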

EMSBE

Once calcTargets.sh has been run, our scripts can implement the OMS algorithms SBV, WIS, FQE and EMSBE as in the Atari experiments of our paper (section 6.3). EMSBE can be implemented by running genTargetDict.py. This script calculates the validation EMSBE for the Q-networks saved in tmp/dqn, tmp/dqn_deep or both (if they exist), every 40 iterations from 40 to 2000, and writes the output to a target_dict.csv file, as in the Atari experiments of our paper. The script can easily be modified to calculate the EMSBE for different Q-functions or different runs: all it requires is that the files y_[run]_[iter] and q_vals_[run]_[iter].npy exist for the given run and iterations, as calculated by calcTargets.py.
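Given the arrays produced by calcTargets.py (loaded, for example, with numpy.load), validation EMSBE can be computed as a weighted mean of squared differences between $Q_m(s_t,a_t)$ and the backup targets, using the weights in the second column of the y_ file. The sketch below assumes exactly that definition; genTargetDict.py may handle weighting or averaging details differently, and the function names are hypothetical.

```python
import numpy as np


def emsbe(targets, q_vals):
    """Weighted empirical mean squared Bellman error.

    targets: (n, 2) array laid out like y_[run]_[iter] (backup target, weight).
    q_vals:  (n,) array of Q_m(s_t, a_t) on the same transitions.
    """
    targets = np.asarray(targets, dtype=np.float64)
    y, w = targets[:, 0], targets[:, 1]
    sq_err = (np.asarray(q_vals, dtype=np.float64) - y) ** 2
    return float(np.sum(w * sq_err) / np.sum(w))


def rank_by_emsbe(scores):
    """Sort (run, iteration) keys by ascending EMSBE; lower is preferred."""
    return sorted(scores, key=scores.get)
```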

WIS

WIS can be implemented by running the trainPropensity.py script within the Asterix, Breakout, Pong or Seaquest subdirectories. This script trains a propensity network $\mu_\beta$ (an estimated behavioral policy) as discussed in Appendix section D.3 of our paper. These scripts were used to train the propensity networks for the Asterix, Breakout, Pong and Seaquest experiments, respectively, whose results are discussed in section 6.3 of our paper. The propensity network training configuration differs between these environments, hence the separate files. The Pong file differs the most and uses a much simpler architecture. The remaining files are more similar to each other, apart from minor hyperparameter changes such as the learning rate. The trainPropensity.py script was designed to maximize computational and memory efficiency per training epoch subject to loading the entire dataset into RAM at once. Greater memory efficiency can be obtained by instead loading partial data periodically from data/train. This is how we implemented SBV, and modifying trainPropensity.py to do the same should not be too difficult. The script saves the files propensity.csv and propensity.hdf5 to the models directory, containing the training progress and the saved weights from the best iteration based on validation log-likelihood. It also saves a length-val_size vector propensities.npy to targetsV containing the estimated propensities $\mu_\beta(s_t,a_t)$ for the validation state-action pairs $(s_t,a_t)$. Once these estimated validation propensities are saved, genWISDict.py calculates WIS for the Q-networks saved in tmp/dqn, tmp/dqn_deep or both (if they exist), every 40 iterations from 40 to 2000, and writes the output to a wis_dict.csv file.
It is easy enough to modify the script to calculate WIS for different Q-functions or for different runs and iterations: all it requires is that the files q_policies_[run]_[iter].npy exist for the given run and iterations, as calculated by calcTargets.py.
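For reference, a textbook trajectory-wise WIS estimator for a deterministic target policy $\pi_m$ is sketched below. Whether this matches the variant used in genWISDict.py is an assumption; the paper's appendix gives the exact form, which may, for example, discount, truncate or weight per decision differently. The inputs are assumed to be aligned validation arrays: episode IDs and logged actions and rewards from data/validate/replay_logs, policy actions from q_policies_[run]_[iter].npy and propensities from propensities.npy. With a deterministic policy, a single mismatch between $\pi_m(s_t)$ and $a_t$ zeroes that episode's weight.

```python
import numpy as np


def wis(episodes, actions, rewards, policy_actions, propensities, gamma=1.0):
    """Trajectory-wise weighted importance sampling for a deterministic policy.

    Episode weight: prod_t 1{pi(s_t) == a_t} / mu(s_t, a_t).
    Episode return: sum_t gamma^t r_t (rows must be in time order per episode).
    Returns the weighted average of returns, or nan if every weight is zero.
    """
    episodes = np.asarray(episodes)
    match = np.asarray(policy_actions) == np.asarray(actions)
    ratios = match / np.asarray(propensities, dtype=np.float64)
    rewards = np.asarray(rewards, dtype=np.float64)
    weights, returns = [], []
    for e in np.unique(episodes):
        idx = np.flatnonzero(episodes == e)
        weights.append(np.prod(ratios[idx]))
        returns.append(np.sum(gamma ** np.arange(len(idx)) * rewards[idx]))
    weights, returns = np.array(weights), np.array(returns)
    total = weights.sum()
    return float(weights @ returns / total) if total > 0 else float("nan")
```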

FQE

FQE can be implemented by running the fitQEval.py script. This script runs FQE using a training configuration similar to runDqn.sh, i.e. the DQN (Adam) configuration of Agarwal, Schuurmans and Norouzi 2020 or the shallow DQN configuration of our paper, with the obvious modifications needed to perform policy evaluation instead of optimal Q-function estimation. The script takes three command-line arguments, run, iter and prev_iter, where run and iter should be chosen such that the file targetsT/q_policies_[run]_[iter].npy exists, as calculated by calcTargets.py. The script then runs FQE for 100 epochs, saving the training progress as models/q_model_[run]_[iter].csv, the trained weights as models/q_model_[run]_[iter].hdf5 and the estimated expected returns (based on the trained FQE model) as targetsV/value_[run]_[iter].npy. We set prev_iter to iter-40, assuming the FQE model for iter-40 had already finished training. If models/q_model_[run]_[prev_iter].hdf5 exists, the script uses these weights as starting values and halves the number of training epochs from 100 to 50. Otherwise, it uses random starting weights and runs for 100 epochs. The fitQEval.py script was designed to maximize computational and memory efficiency subject to loading the entire dataset into RAM at once. Greater memory efficiency can be obtained by instead loading partial data periodically from data/train. This is how we implemented SBV, and modifying fitQEval.py to do the same should not be too difficult. Once the desired estimated expected returns are saved, genFQEDict.py calculates FQE for the Q-networks saved in tmp/dqn, tmp/dqn_deep or both (if they exist), every 40 iterations from 40 to 2000, and writes the output to a fqe_dict.csv file.
It is easy enough to modify the script to calculate FQE for different Q-functions or for different runs and iterations: all it requires is that the files y_[run]_[iter] and q_policies_[run]_[iter] exist for the given run and iterations, as calculated by calcTargets.sh.
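The key difference between FQE and DQN lies in the backup: FQE bootstraps with the action chosen by the fixed policy $\pi_m$ rather than with $\max_{a'}$. The sketch below illustrates this target together with the warm-start rule and the sequential iter/iter-40 calls described above. It is illustrative only: the function names are hypothetical, $\gamma=0.99$ is assumed, and the internals of fitQEval.py may differ.

```python
import os

import numpy as np


def fqe_targets(rewards, next_terminals, next_q_eval, next_policy_actions, gamma=0.99):
    """FQE backup r_t + (1 - d_{t+1}) * gamma * Q_eval(s_{t+1}, pi_m(s_{t+1})).

    next_q_eval: (n, |A|) outputs of the FQE network being trained, at s_{t+1}.
    next_policy_actions: pi_m(s_{t+1}), the greedy actions of the evaluated Q_m.
    Unlike the DQN target, no max over actions is taken.
    """
    next_q_eval = np.asarray(next_q_eval, dtype=np.float64)
    rows = np.arange(len(next_q_eval))
    bootstrap = next_q_eval[rows, np.asarray(next_policy_actions)]
    not_done = 1.0 - np.asarray(next_terminals, dtype=np.float64)
    return np.asarray(rewards, dtype=np.float64) + not_done * gamma * bootstrap


def fqe_schedule(run, it, prev_it, models_dir="models"):
    """(starting weights, epochs): warm start with 50 epochs if the previous
    FQE model exists, otherwise random initialisation with 100 epochs."""
    prev = os.path.join(models_dir, f"q_model_{run}_{prev_it}.hdf5")
    return (prev, 50) if os.path.exists(prev) else (None, 100)


def fqe_commands(run, iterations):
    """Sequential `python fitQEval.py run iter iter-40` calls."""
    return [["python", "fitQEval.py", run, str(i), str(i - 40)] for i in iterations]
```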

SBV

Finally, SBV can be implemented by running the runSBV.py script within the Asterix, Breakout, Pong or Seaquest subdirectories. These scripts were used to train the Bellman networks for the Asterix, Breakout, Pong and Seaquest experiments, respectively, as discussed in section 6.3 and Appendix section D.2 of our paper. The Bellman network training configuration differs between these environments, hence the separate files. The Pong file differs the most because it uses a much simpler architecture. Moreover, for Pong, Asterix and Seaquest, we avoided loading all data into memory at once, instead loading only a few files from shards/train at a time to improve memory efficiency. For Breakout, we found that this strategy led to a moderate increase in validation error (around 5%), so we loaded the full dataset into memory instead. However, a similar memory-efficient implementation for Breakout is available in the runSBV_memory.py script. The script takes as arguments a run and an arbitrary list of iterations iter1 iter2 .... For example, running python runSBV.py ddqn_Srch 40 80 120 trains three Bellman networks based on the Q-Networks of the dqn_deep run after iterations 40, 80 and 120. Moreover, for Pong, Seaquest and Breakout, the Bellman network trained for dqn_deep iteration 40 is used as the starting value when training the Bellman network for dqn_deep iteration 80, and similarly for iteration 120, to improve computational efficiency. To avoid overfitting, we did not chain warm starts from previous iterations more than twice for Breakout and Seaquest or more than four times for Pong. So, for instance, to train Bellman networks for iterations 40-240 for Seaquest and Breakout, we would run python runSBV.py ddqn_Srch 40 80 120 and then python runSBV.py ddqn_Srch 160 200 240.
For Asterix, any such warm start led to a moderate increase in validation error, so running python runSBV.py ddqn_Srch 40 80 120 160 200 240 with the Asterix script trains Bellman networks for iterations 40-240 using random starting weights for all six Bellman networks.
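The warm-start limits above amount to splitting the iteration list into consecutive groups, one runSBV.py call per group. The sketch below encodes that rule. It assumes each new runSBV.py call starts from random weights (as the Seaquest/Breakout example implies), and `sbv_batches`, `sbv_commands` and `MAX_WARM_STARTS` are hypothetical names. For Asterix every network already uses random starting weights, so a single call suffices.

```python
# Longest allowed chain of warm starts per game (None: no warm starts at all).
MAX_WARM_STARTS = {"Breakout": 2, "Seaquest": 2, "Pong": 4, "Asterix": None}


def sbv_batches(iterations, max_warm_starts):
    """Group iterations so each Bellman network is warm-started from its
    predecessor at most `max_warm_starts` times in a row."""
    iterations = list(iterations)
    if max_warm_starts is None:
        return [iterations]  # Asterix script: random starts for every network
    size = max_warm_starts + 1  # first network of each call starts from scratch
    return [iterations[i:i + size] for i in range(0, len(iterations), size)]


def sbv_commands(game, run, iterations):
    """One `python runSBV.py run iter1 iter2 ...` call per batch."""
    return [["python", "runSBV.py", run, *map(str, batch)]
            for batch in sbv_batches(iterations, MAX_WARM_STARTS[game])]
```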

When running python runSBV.py run iter1 iter2 ..., the files models/[run]_[iter1].hdf5, models/[run]_[iter2].hdf5, etc. store the trained weights from the best iteration, and models/[run]_[iter1].csv, models/[run]_[iter2].csv, etc. store information about the training progress. Finally, the files targetsV/backups_[run]_[iter1].npy, targetsV/backups_[run]_[iter2].npy, etc. store the estimated Bellman backups $\mathcal B_\phi(s_t,a_t)$, where $(s_t,a_t)$ are state-action pairs from the validation set and $\mathcal B_\phi$ is the trained Bellman network saved in the models directory. Once the desired estimated backups are saved, genSBVDict.py calculates SBV for the Q-networks saved in tmp/dqn, tmp/dqn_deep or both (if they exist) and writes the output to a sbv_dict.csv file. It is easy enough to modify the runSBV.py and genSBVDict.py scripts to run SBV on different Q-functions or for different runs and iterations: all they require is that the files y_[run]_[iter] and q_vals_[run]_[iter] exist for the given run and iterations, as calculated by calcTargets.sh.
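Once the backups are saved, the SBV criterion for each Q-function compares $Q_m(s_t,a_t)$ with the estimated backup $\mathcal B_\phi(s_t,a_t)$ on the validation set, and the Q-function with the smallest value is preferred. The sketch below assumes SBV is the (optionally weighted) mean squared difference between the two, with weights taken from the second column of the y_ file; genSBVDict.py may differ in such details, and the function names are hypothetical.

```python
import numpy as np


def sbv_score(backups, q_vals, weights=None):
    """Mean squared difference between B_phi(s_t, a_t) and Q_m(s_t, a_t),
    optionally weighted (e.g. to drop invalid validation states)."""
    backups = np.asarray(backups, dtype=np.float64)
    q_vals = np.asarray(q_vals, dtype=np.float64)
    diff2 = (backups - q_vals) ** 2
    if weights is None:
        return float(diff2.mean())
    weights = np.asarray(weights, dtype=np.float64)
    return float(np.sum(weights * diff2) / np.sum(weights))


def select_q_function(scores):
    """Key (e.g. (run, iteration)) with the smallest SBV estimate."""
    return min(scores, key=scores.get)
```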

Reproducing Atari Tables and Figures

To reproduce our tables and figures, we first use processData.sh to process the data in the following subdirectories of the bucket gs://atari-replay-datasets/dqn: Pong/1, Pong/2, Pong/4, Seaquest/1, Seaquest/2, Seaquest/4, Breakout/1, Breakout/2, Breakout/3, Asterix/2, Asterix/3 and Asterix/5. We did not use datasets 1, 2 and 4 for Breakout and Asterix because Asterix/1, Breakout/4 and Breakout/5 were all missing files when we started our research. We then run sh runDqn.sh on Pong and both sh runDqn.sh and sh rundqn_deep.sh on the remaining three games. Next, we run sh calcTargets.sh dqn on Pong and both sh calcTargets.sh dqn and sh calcTargets.sh dqn_deep on the remaining three games. We then use trainPropensity.py and runSBV.py to train the necessary propensity and Bellman networks, and run genDicts.sh to create the target_dict.csv, wis_dict.csv and sbv_dict.csv files. Finally, the resulting Atari directories (e.g. Breakout1, Seaquest4, Asterix5) containing the files target_dict.csv, wis_dict.csv and sbv_dict.csv should all be moved into the same parent directory. From this parent directory, running plotGames2.R and plotGames.R generates Figures 4 and G.5 of our paper, respectively, and running atariOracle.sh generates Figure G.4. Note that due to the stochasticity of deep learning training, the resulting figures and numbers may differ somewhat from those in our paper, but the general conclusion that SBV outperforms EMSBE, WIS and FQE should still be apparent.
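The per-dataset order of operations above can be summarised as a printable checklist. The following Python sketch only prints steps and is not a tested pipeline. `DATASETS`, `dqn_runs` and `pipeline_steps` are hypothetical names, the runSBV.py arguments are left as a placeholder, the game-specific trainPropensity.py and runSBV.py are assumed to be used, and the deep-DQN script name should be checked against the repository because this README spells it both runDqn_deep.sh and rundqn_deep.sh.

```python
# Game -> dataset indices processed for the paper (from the list above).
DATASETS = {
    "Pong": [1, 2, 4],
    "Seaquest": [1, 2, 4],
    "Breakout": [1, 2, 3],
    "Asterix": [2, 3, 5],
}
# Deep-DQN launcher; verify the exact file name in atari_scripts.
DEEP_DQN_SCRIPT = "runDqn_deep.sh"


def dqn_runs(game):
    """Pong uses only the simple DQN run; the other games also use dqn_deep."""
    return ["dqn"] if game == "Pong" else ["dqn", "dqn_deep"]


def pipeline_steps(game, index):
    """Ordered steps for one dataset. Everything after processData.sh runs
    inside the [game][index] directory that processData.sh creates."""
    runs = dqn_runs(game)
    steps = [f"sh processData.sh {game} {index}", f"cd {game}{index}", "sh runDqn.sh"]
    if "dqn_deep" in runs:
        steps.append(f"sh {DEEP_DQN_SCRIPT}")
    steps += [f"sh calcTargets.sh {run}" for run in runs]
    steps += [
        "python trainPropensity.py",                 # game-specific version
        "python runSBV.py RUN ITER1 ITER2 ...",      # placeholder arguments
        "sh genDicts.sh",
    ]
    return steps


if __name__ == "__main__":
    for game, indices in DATASETS.items():
        for index in indices:
            print(f"# {game}{index}")
            print("\n".join(pipeline_steps(game, index)))
```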

To make reproducibility easier, we also store the Q-functions analyzed in section 6.3 of our paper, as well as the wis_dict.csv, sbv_dict.csv and target_dict.csv files used to generate Table 2 and Figures 4, G.4 and G.5, in the atari_objects directory of this repository. The FQE results for Seaquest discussed in section 6.3 of our paper can be found in the fqe_dict.csv file located in atari_objects/Seaquest1.

References

[1] Joshua P. Zitovsky, Daniel de Marchi, Rishabh Agarwal and Michael R. Kosorok. "Revisiting Bellman Errors for Offline Model Selection". arXiv, abs/2302.00141, 2023.

[2] Rishabh Agarwal, Dale Schuurmans and Mohammad Norouzi. "An Optimistic Perspective on Offline Reinforcement Learning". ICML, 2020.

[3] Pablo S. Castro, Subhodeep Moitra, Carles Gelada, Saurabh Kumar and Marc G. Bellemare. "Dopamine: A Research Framework for Deep Reinforcement Learning". arXiv, abs/1812.06110, 2018.
