This repository contains an implementation of a deep neural network architecture that combines graph neural networks (GNNs) and temporal convolutional networks (TCNs) to learn from both the spatial and temporal components of rs-fMRI data in an end-to-end fashion. See the publications at the end of this page for more details on how this architecture was used and evaluated.
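For rough intuition about how spatial and temporal information can be combined, the following toy NumPy sketch processes one synthetic scan. A causal dilated convolution extracts temporal features for each brain region. A graph, built by thresholding the absolute Pearson correlation between regions, is then used for one round of neighbour averaging, and the result is mean-pooled into a single vector. This is not SpatioTemporalModel. The function names, the correlation threshold and the fixed (untrained) kernel are assumptions made for illustration, whereas the real model learns its weights end-to-end.

```python
import numpy as np


def causal_dilated_conv(ts, kernel, dilation=1):
    """Apply the same causal dilated 1D convolution to every region's time series.

    ts has shape (n_regions, n_timepoints). The output has the same shape, and
    the value at time t only depends on inputs at times <= t.
    """
    k = len(kernel)
    pad = (k - 1) * dilation
    padded = np.pad(ts, ((0, 0), (pad, 0)))  # left padding only keeps the filter causal
    out = np.zeros(ts.shape, dtype=float)
    for j in range(k):
        # kernel[j] multiplies the input shifted back by (k - 1 - j) * dilation steps
        start = j * dilation
        out += kernel[j] * padded[:, start:start + ts.shape[1]]
    return out


def correlation_graph(ts, threshold=0.5):
    """Adjacency matrix linking regions whose |Pearson correlation| >= threshold."""
    adj = (np.abs(np.corrcoef(ts)) >= threshold).astype(float)
    np.fill_diagonal(adj, 1.0)  # self-loops: every region keeps its own features
    return adj


def spatio_temporal_embedding(ts, kernel, dilation=1, threshold=0.5):
    """Temporal step, then one spatial (graph) step, then pooling to one vector per scan."""
    h = np.maximum(causal_dilated_conv(ts, kernel, dilation), 0.0)  # temporal features + ReLU
    adj = correlation_graph(ts, threshold)
    h = adj @ h / adj.sum(axis=1, keepdims=True)  # mean over each region's neighbours
    return h.mean(axis=0)  # mean pooling over regions


rng = np.random.default_rng(0)
ts = rng.standard_normal((10, 50))  # synthetic data: 10 regions, 50 time points
emb = spatio_temporal_embedding(ts, kernel=np.array([0.2, 0.3, 0.5]), dilation=2)
print(emb.shape)  # (50,)
```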
If something is unclear or you have any questions, please open an issue.
The code in this repository relies on Weights & Biases (W&B) to track and organise the results of experiments. W&B was also used to conduct the hyperparameter search, and all the sweeps needed for it are defined in the wandb_sweeps/ folder. All our runs, sweep definitions and reports are publicly available on our project's W&B page; in particular, two reports briefly summarise the main results of our experiments.
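A sweep definition can also be written as a Python dictionary, which makes it easy to produce one definition per cross-validation fold, as this repository does with its separate sweep files. The sketch below is hypothetical: "name", "program", "method" and "parameters" are standard W&B sweep keys, but the hyperparameter names, their values and the fold_num parameter are made up and do not mirror the files in wandb_sweeps/.

```python
import copy

# Hypothetical base sweep, i.e. the Python equivalent of one YAML file.
# "name", "program", "method" and "parameters" are standard W&B sweep keys;
# the hyperparameters and their values are made up for illustration.
BASE_SWEEP = {
    "name": "example_sweep",
    "program": "main_loop.py",
    "method": "random",
    "parameters": {
        "lr": {"values": [1e-2, 1e-3, 1e-4]},
        "dropout": {"values": [0.0, 0.3, 0.5]},
    },
}


def per_fold_sweeps(base, n_folds):
    """Return one sweep definition per cross-validation fold (numbered 1..n_folds).

    Each copy pins a hypothetical "fold_num" hyperparameter with the constant
    {"value": ...} form, so every run of that sweep trains on the same fold.
    """
    sweeps = []
    for fold in range(1, n_folds + 1):
        sweep = copy.deepcopy(base)  # nested dicts must not be shared between folds
        sweep["name"] = f"{base['name']}_fold{fold}"
        sweep["parameters"]["fold_num"] = {"value": fold}
        sweeps.append(sweep)
    return sweeps


sweeps = per_fold_sweeps(BASE_SWEEP, n_folds=5)
```

Each dictionary could then be saved as YAML (for instance with PyYAML's yaml.safe_dump) and registered with wandb sweep. The deep copy matters: with a shallow copy, all folds would share the same nested "parameters" dictionary and end up with the last fold's value.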
We recommend that anyone who wants to run or extend our code first becomes familiar with the W&B online documentation. For example, we created a sweep by running the following command in a terminal:
$ wandb sweep --entity st-team wandb_sweeps/st_ukb_uni_gender_1_fmri_none_nodemeta_mean_128.yaml
This yielded an identifier (in this case qqqjagns), which allowed us to launch 25 runs with randomly sampled hyperparameters by executing:
$ wandb agent st-team/spatio-temporal-brain/qqqjagns --count=25
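Since these are random sweeps, each run started by the agent receives one value for every hyperparameter, drawn at random from the sweep's search space. The pure-Python sketch below mimics that behaviour for parameters given as lists of values. It is a simplified illustration rather than W&B's actual implementation, and the search space is hypothetical.

```python
import random


def sample_random_runs(parameters, count, seed=None):
    """Draw `count` hyperparameter combinations, picking one value per parameter
    uniformly at random from its list of allowed values.

    `parameters` follows the shape of the "parameters" section of a W&B sweep
    definition that uses the {"values": [...]} form.
    """
    rng = random.Random(seed)
    return [
        {name: rng.choice(spec["values"]) for name, spec in parameters.items()}
        for _ in range(count)
    ]


# Hypothetical search space; the real ones are in the wandb_sweeps/ YAML files.
space = {
    "lr": {"values": [1e-2, 1e-3, 1e-4]},
    "dropout": {"values": [0.0, 0.3, 0.5]},
}
runs = sample_random_runs(space, count=25, seed=0)
```

Unlike grid search, this sampler draws with replacement: with only 9 combinations in this toy space, 25 runs necessarily repeat some of them, whereas in a realistic search space most runs explore new combinations.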
The wandb agent executes main_loop.py with a set of hyperparameters drawn from its sweep definition (each sweep is defined by one of the *.yaml files inside the wandb_sweeps folder). Note that we use a different sweep file for each cross-validation fold.
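Inside a training script, the hyperparameters of a run arrive as plain values (mostly strings and numbers), and a natural pattern is to convert them into typed flags such as the enums defined in utils.py. A minimal sketch of that pattern follows. It assumes a plain dict in place of the run's W&B config; ExamplePooling, parse_run_config and the hyperparameter names are hypothetical and are not the repository's actual code.

```python
from enum import Enum


class ExamplePooling(Enum):
    # Hypothetical flag: the real enums (SweepType, Normalisation, DatasetType, ...)
    # live in utils.py, and their members are not reproduced here.
    MEAN = "mean"
    CONCAT = "concat"


def parse_run_config(config):
    """Turn the raw hyperparameters of one run into typed values.

    `config` is a plain dict standing in for the configuration W&B passes to a
    run. Enum(value) raises ValueError for an unknown string.
    """
    return {
        "pooling": ExamplePooling(config["pooling"]),
        "lr": float(config["lr"]),
        "fold_num": int(config["fold_num"]),
    }


flags = parse_run_config({"pooling": "mean", "lr": "0.001", "fold_num": 1})
```

Converting through the enum constructor means that a typo in a sweep file (say, "maen") stops the run immediately with a ValueError instead of silently falling back to a default.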
The file meta_data/st_env.yml contains the exact dependencies used to develop and run this repository. To install all of them automatically with Miniconda or Anaconda, run the following commands in a terminal to create and activate a conda environment:
$ conda env create -f meta_data/st_env.yml
$ conda activate st_env
The main packages used by this repository are:
- matplotlib==3.1.3
- networkx==2.4
- pandas==1.0.2
- python==3.7
- pytorch==1.4.0
- scikit-learn==0.22.2
- seaborn==0.10.1
- torch-geometric==1.4.2
- wandb==0.8.31
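Because the environment pins exact versions, it can be worth checking that an existing environment matches them before running experiments. The helpers below are a hypothetical sketch: they parse the list above and compare it against a dictionary of installed versions that you supply.

```python
def parse_pins(lines):
    """Parse "- name==version" lines into a {name: version} dict."""
    pins = {}
    for line in lines:
        entry = line.strip().lstrip("-").strip()
        if "==" in entry:
            name, version = entry.split("==", 1)
            pins[name.strip()] = version.strip()
    return pins


def find_mismatches(pins, installed):
    """Return {name: (pinned, found)} for every pinned package whose installed
    version differs; `found` is None when the package is missing."""
    return {
        name: (version, installed.get(name))
        for name, version in pins.items()
        if installed.get(name) != version
    }


PINNED = parse_pins("""
- matplotlib==3.1.3
- networkx==2.4
- pandas==1.0.2
- python==3.7
- pytorch==1.4.0
- scikit-learn==0.22.2
- seaborn==0.10.1
- torch-geometric==1.4.2
- wandb==0.8.31
""".splitlines())
```

Installed versions can be collected with importlib.metadata.version on Python 3.8+ (or the importlib_metadata backport on Python 3.7). Bear in mind that some conda package names differ from their Python distribution names (for example, pytorch is installed as the torch distribution), and that python itself is better checked with sys.version_info.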
The best entry point for understanding how the code works is the file executed by the wandb agent, main_loop.py. It contains all the code needed to read the hyperparameters passed by the wandb agent and to train a model accordingly. The files it depends on are mostly in the root of this repository:
- datasets.py: classes to load datasets into memory, namely HCPDataset for the Human Connectome Project and UKBDataset for the UK Biobank. Both inherit from BrainDataset, which is built on PyTorch Geometric's InMemoryDataset class.
- model.py: the main spatio-temporal model of this repository, SpatioTemporalModel, which is built according to different flags passed as arguments.
- tcn.py: TCN adaptation, originally taken from https://github.com/locuslab/TCN/blob/master/TCN/tcn.py
- utils.py: many utility functions. Note the enums defined at the very beginning (e.g., SweepType, Normalisation, DatasetType), which represent the flags that can be set by the wandb agent or, more generally, in the code.
- utils_datasets.py: many constants needed specifically for dataset handling.
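Temporal convolutional networks stack causal convolutions with exponentially growing dilation, so the span of past time points each output can see grows quickly with depth. Assuming tcn.py keeps the structure of the locuslab original (each level i stacks two causal convolutions with dilation 2**i), the receptive field can be computed as below. The helper names are illustrative and are not part of the repository.

```python
def tcn_receptive_field(kernel_size, num_levels, convs_per_block=2):
    """Time steps (current one included) visible to one output of the last level.

    Level i uses dilation 2**i, and each of its causal convolutions extends the
    window by (kernel_size - 1) * 2**i, so the total is
    1 + convs_per_block * (kernel_size - 1) * (2**num_levels - 1).
    """
    return 1 + convs_per_block * (kernel_size - 1) * (2 ** num_levels - 1)


def min_levels_for(seq_len, kernel_size, convs_per_block=2):
    """Smallest number of levels whose receptive field covers seq_len time steps."""
    if kernel_size < 2:
        raise ValueError("kernel_size must be at least 2 for the window to grow")
    levels = 0
    while tcn_receptive_field(kernel_size, levels, convs_per_block) < seq_len:
        levels += 1
    return levels
```

For instance, with kernel_size=7, three levels cover 85 time points (tcn_receptive_field(7, 3) returns 85), and min_levels_for(86, 7) returns 4. Because dilation doubles per level, the depth needed grows only logarithmically with the number of time points.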
With regard to the folders in this repository:
- encoders: old implementations of autoencoders that we tried (they are defined as possible flags in utils.py); however, these models never gave good results and therefore do not appear in the paper.
- meta_data: auxiliary ("meta") files needed by this repository.
- outputs: old output files from preliminary work using the Human Connectome Project (see the publications below). At the time we were not yet using Weights & Biases, so these outputs are kept in this folder.
- post_analysis: a variety of scripts, mostly for the post-hoc analysis and plotting used in the publications. The code here is quite messy and we plan to refactor it in the future.
- wandb_sweeps: all the Weights & Biases sweep configuration files.
Data cannot be publicly shared in this repository; we are working on providing more information about this as soon as possible.
The architecture implemented in this repository is described in detail in a preprint on bioRxiv. If you use this architecture in your research, please cite the paper using the following BibTeX entry:
@article{Azevedo2020,
doi = {10.1101/2020.11.08.370288},
url = {https://doi.org/10.1101/2020.11.08.370288},
year = {2020},
month = nov,
publisher = {Cold Spring Harbor Laboratory},
author = {Tiago Azevedo and Alexander Campbell and Rafael Romero-Garcia and Luca Passamonti and Richard A.I. Bethlehem and Pietro Lio and Nicola Toschi},
title = {A Deep Graph Neural Network Architecture for Modelling Spatio-temporal Dynamics in resting-state functional {MRI} Data}
}
Two preliminary versions of this work were also presented at other venues and are available online:
- A deep spatiotemporal graph learning architecture for brain connectivity analysis. EMBC 2020. DOI: 10.1109/EMBC44109.2020.9175360.
- Towards a predictive spatio-temporal representation of brain data. Ai4AH @ ICLR 2020. arXiv: 2003.03290.