-
Notifications
You must be signed in to change notification settings - Fork 17
/
STELLAR_run.py
45 lines (41 loc) · 2.25 KB
/
STELLAR_run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import argparse
from utils import prepare_save_dir
from STELLAR import STELLAR
import numpy as np
import os
import torch
from datasets import GraphDataset, load_tonsilbe_data, load_hubmap_data
def main():
parser = argparse.ArgumentParser(description='STELLAR')
parser.add_argument('--dataset', default='Hubmap', help='dataset setting')
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
parser.add_argument('--name', type=str, default='STELLAR')
parser.add_argument('--epochs', type=int, default=20)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--wd', type=float, default=5e-2)
parser.add_argument('--num-heads', type=int, default=22)
parser.add_argument('--num-seed-class', type=int, default=0)
parser.add_argument('--sample-rate', type=float, default=0.5)
parser.add_argument('-b', '--batch-size', default=1, type=int,
metavar='N',
help='mini-batch size')
parser.add_argument('--distance_thres', default=50, type=int)
parser.add_argument('--savedir', type=str, default='./')
args = parser.parse_args()
args.cuda = torch.cuda.is_available()
args.device = torch.device("cuda" if args.cuda else "cpu")
# Seed the run and create saving directory
args.name = '_'.join([args.dataset, args.name])
args = prepare_save_dir(args, __file__)
if args.dataset == 'Hubmap':
labeled_X, labeled_y, unlabeled_X, labeled_edges, unlabeled_edges, inverse_dict = load_hubmap_data('./data/B004_training_dryad.csv', './data/B0056_unnanotated_dryad.csv', args.distance_thres, args.sample_rate)
dataset = GraphDataset(labeled_X, labeled_y, unlabeled_X, labeled_edges, unlabeled_edges)
elif args.dataset == 'TonsilBE':
labeled_X, labeled_y, unlabeled_X, labeled_edges, unlabeled_edges, inverse_dict = load_tonsilbe_data('./data/BE_Tonsil_l3_dryad.csv', args.distance_thres, args.sample_rate)
dataset = GraphDataset(labeled_X, labeled_y, unlabeled_X, labeled_edges, unlabeled_edges)
stellar = STELLAR(args, dataset)
stellar.train()
_, results = stellar.pred()
np.save(os.path.join(args.savedir, args.dataset + '_results.npy'), results)
if __name__ == '__main__':
main()