
JointViT: Modeling Oxygen Saturation Levels with Joint Supervision on Long-Tailed OCTA
MIUA 2024 Oral

Zeyu Zhang, Xuyin Qi, Mingxi Chen, Guangxi Li, Ryan Pham, Ayub Qassim, Ella Berry, Zhibin Liao, Owen Siggs, Robert Mclaughlin, Jamie Craig, Minh-Son To

Website DOI arXiv Papers With Code BibTeX

The oxygen saturation level in the blood (SaO2) is crucial for health, particularly in relation to sleep-related breathing disorders. However, continuous monitoring of SaO2 is time-consuming and highly variable depending on patients' conditions. Recently, optical coherence tomography angiography (OCTA) has shown promising development in rapidly and effectively screening eye-related lesions, offering the potential for diagnosing sleep-related disorders. To bridge this gap, our paper presents three key contributions. Firstly, we propose JointViT, a novel model based on the Vision Transformer architecture, incorporating a joint loss function for supervision. Secondly, we introduce a balancing augmentation technique during data preprocessing to improve the model's performance, particularly on the long-tail distribution within the OCTA dataset. Lastly, through comprehensive experiments on the OCTA dataset, our proposed method significantly outperforms other state-of-the-art methods, achieving improvements of up to 12.28% in overall accuracy. This advancement lays the groundwork for the future utilization of OCTA in diagnosing sleep-related disorders.
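As a rough illustration of the joint-supervision idea (a hedged sketch, not the exact loss from the paper — the real formulation and weighting are defined in the paper and in this repo's code), a joint objective could combine a classification term over SaO2 categories with a regression term on the raw SaO2 value:

import torch
import torch.nn as nn

# Hypothetical joint loss: BCE over SaO2 category logits plus MSE on the
# predicted SaO2 value, weighted by bce_weight (cf. 'bce_weight' in the
# args dict under Training below). Names and shapes here are assumptions.
def joint_loss(class_logits, class_targets, sao2_pred, sao2_true, bce_weight=1.0):
    bce = nn.functional.binary_cross_entropy_with_logits(class_logits, class_targets)
    mse = nn.functional.mse_loss(sao2_pred, sao2_true)
    return bce_weight * bce + mse

# Toy usage with random tensors (5 SaO2 categories, batch of 4).
logits = torch.randn(4, 5)
targets = torch.randint(0, 2, (4, 5)).float()
pred, true = torch.randn(4), torch.randn(4)
print(joint_loss(logits, targets, pred, true))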


News

(06/18/2024) 🎉 Our paper has been selected as an oral presentation at MIUA 2024!

(05/14/2024) 🎉 Our paper has been accepted to MIUA 2024!

Hardware

NVIDIA GeForce GTX TITAN X

Environment

For the Docker container:

docker pull qiyi007/oct:1.0

For dependencies:

conda create -n jointvit python=3.8
conda activate jointvit
conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=11.0 -c pytorch
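To sanity-check the install, this one-liner should print the PyTorch version and whether CUDA is visible:

python -c "import torch; print(torch.__version__, torch.cuda.is_available())"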

Dataset

The file directory structure is as follows (3-fold):

|-- OCT-Code
|   |-- util
|   |   |-- ...
|   |-- vit_pytorch
|   |   |-- ...
|   |-- train.py
|   |-- train_0 <split fold index information file>
|   |-- train_1 <split fold index information file>
|   |-- train_2 <split fold index information file>
|   |-- test_0 <split fold index information file>
|   |-- test_1 <split fold index information file>
|   |-- test_2 <split fold index information file>
|   |-- Sleep-results.xlsx <label information file>
|   |-- images-with-labels <dataset folder>
|   |   |-- ...

OR:

|-- OCT-Code
|   |-- util
|   |   |-- ...
|   |-- vit_pytorch
|   |   |-- ...
|   |-- train.py
|   |-- train_data <dataset folder>
|   |   |-- ...
|   |-- test_data <dataset folder>
|   |   |-- ...
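The train_*/test_* files above hold the per-fold sample indices. As a hedged illustration only (the repo's actual index file format is not shown here and may differ), a 3-fold split could be generated with scikit-learn and written to files of this naming scheme:

import numpy as np
from sklearn.model_selection import KFold

# Hypothetical: n_samples stands in for the number of labeled OCTA samples.
n_samples = 100
kf = KFold(n_splits=3, shuffle=True, random_state=0)

for fold, (train_idx, test_idx) in enumerate(kf.split(np.arange(n_samples))):
    # One index file per fold, matching the train_0/test_0 naming above.
    np.savetxt(f"train_{fold}", train_idx, fmt="%d")
    np.savetxt(f"test_{fold}", test_idx, fmt="%d")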

Checkpoints

Training

For Colab users, you can use the provided notebook.

Modify the args dict in train.py:

# lr, split, isaug, isbalval, and run_name are set earlier in train.py
args = {
    'device': torch.device("cuda:1"),
    # 'model': get_model_octa_resume(outsize=5, path='ckpt_path', dropout=0.15),
    # 'model': get_model_conv(pretrain_out=4, outsize=5, path='/OCT-Covid/covid_ckpts/oct4class_biglr/val_acc0.9759836196899414.pt'),
    'model': get_vani(outsize=5, dropout=0.25),
    # 'model': get_model_oct_withpretrain(pretrain_out=4, outsize=5, path='/OCT-Covid/covid_ckpts/oct4class_biglr/val_acc0.9759836196899414.pt', dropout=0.15),
    'save_path': 'save_path',   # where checkpoints are written
    'bce_weight': 1,            # weight of the BCE term in the joint loss
    'epochs': 200,
    'lr': lr,
    'batch_size': 300,
    'datasets': get_dataUNI(split_idx=split, aug_class=isaug, bal_val=isbalval),
    'vote_loader': DataLoader(get_dataUNI(split_idx=split, aug_class=isaug, bal_val=isbalval, infer_3d=True)[1], batch_size=1, shuffle=False),
    'is_echo': False,
    'optimizer': optim.Adam,
    'scheduler': optim.lr_scheduler.CosineAnnealingLR,
    'train_loader': None,
    'eval_loader': None,
    'shuffle': True,
    'is_MIX': True,             # use the mixed (joint) loss input
    'wandb': ['wandb account', 'project name', run_name],
    'decay': 1e-3,              # weight decay
}
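For reference, a minimal sketch of how entries like 'optimizer', 'scheduler', 'lr', 'decay', and 'epochs' are typically wired together in PyTorch (the actual wiring lives in this repo's training code and may differ):

import torch
import torch.optim as optim

# Hypothetical stand-in model; in the repo this is args['model'].
model = torch.nn.Linear(10, 5)

# Adam with weight decay, annealed by a cosine schedule over the epochs.
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-3)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

for epoch in range(200):
    # ... forward pass, loss, backward pass, optimizer.step() ...
    scheduler.step()  # anneal the learning rate once per epoch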

3-fold training & validation:

python train_3fold.py

Default training:

python train.py

After running train.py, the metrics for a test fold will be displayed.

Citation

@inproceedings{zhang2024jointvit,
  title={{JointViT}: Modeling Oxygen Saturation Levels with Joint Supervision on Long-Tailed {OCTA}},
  author={Zhang, Zeyu and Qi, Xuyin and Chen, Mingxi and Li, Guangxi and Pham, Ryan and Qassim, Ayub and Berry, Ella and Liao, Zhibin and Siggs, Owen and Mclaughlin, Robert and others},
  booktitle={Annual Conference on Medical Image Understanding and Analysis},
  pages={158--172},
  year={2024},
  organization={Springer}
}