-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathfit.py
executable file
·77 lines (69 loc) · 2.33 KB
/
fit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#!/usr/bin/env python3
from pathlib import Path
import pandas as pd
import stats
from mgs import Enrichment, MGSData, target_bioprojects
from pathogens import predictors_by_taxid
def summarize_output(coeffs: pd.DataFrame) -> pd.DataFrame:
    """Return descriptive statistics of ``ra_at_1in100`` for each group.

    Rows of ``coeffs`` are grouped by pathogen, tidy name, taxids, predictor
    type, study and location.  For every group the ``ra_at_1in100`` column is
    summarized with ``describe`` using the 5th, 25th, 50th, 75th and 95th
    percentiles.

    Args:
        coeffs: Frame containing the six grouping columns and a
            ``ra_at_1in100`` column.

    Returns:
        A frame indexed by the grouping columns, one row per group, with the
        columns produced by ``describe``.
    """
    group_keys = [
        "pathogen",
        "tidy_name",
        "taxids",
        "predictor_type",
        "study",
        "location",
    ]
    quantiles = [0.05, 0.25, 0.5, 0.75, 0.95]
    per_group = coeffs.groupby(group_keys)
    return per_group["ra_at_1in100"].describe(percentiles=quantiles)
def start(num_samples: int, plot: bool) -> None:
    """Fit a model for every predictor/study pair and write the results as TSV.

    For each entry yielded by ``predictors_by_taxid()`` and each study in
    ``target_bioprojects``, a model is built with ``stats.build_model`` and
    fitted.  Combinations for which ``build_model`` returns ``None`` are
    skipped.  The model inputs and the fitted coefficients, each tagged with
    the pathogen/study metadata, are written to ``input.tsv`` and
    ``fits.tsv``; a per-group summary of the coefficients is written to
    ``fits_summary.tsv``.  All paths are relative to the current working
    directory.

    Args:
        num_samples: Passed through to ``model.fit_model``.
        plot: When true, the ``fig`` directory is created if missing and
            ``model.plot_figures`` is called for every fitted model.

    Raises:
        ValueError: If no model could be built, so there is nothing to write.
    """
    figdir = Path("fig")
    if plot:
        figdir.mkdir(exist_ok=True)
    mgs_data = MGSData.from_repo()
    input_frames = []
    coeff_frames = []
    for (
        pathogen_name,
        tidy_name,
        predictor_type,
        taxids,
        predictors,
    ) in predictors_by_taxid():
        # Both labels depend only on the taxids, so build them once per
        # predictor rather than once per study.  The metadata column joins
        # with "_" while figure file names join with "-".
        taxids_str = "_".join(str(t) for t in taxids)
        fig_taxids = "-".join(str(t) for t in taxids)
        for study, bioprojects in target_bioprojects.items():
            # Every study except "brinch" is modelled with Enrichment.VIRAL.
            enrichment = None if study == "brinch" else Enrichment.VIRAL
            model = stats.build_model(
                mgs_data,
                bioprojects,
                predictors,
                taxids,
                # Seed derived from the taxids so reruns use the same seed.
                random_seed=sum(taxids),
                enrichment=enrichment,
            )
            if model is None:
                continue
            model.fit_model(num_samples=num_samples)
            if plot:
                model.plot_figures(
                    path=figdir,
                    prefix=(
                        f"{pathogen_name}-{fig_taxids}-{predictor_type}-{study}"
                    ),
                )
            metadata = dict(
                pathogen=pathogen_name,
                tidy_name=tidy_name,
                taxids=taxids_str,
                predictor_type=predictor_type,
                study=study,
            )
            input_frames.append(model.input_df.assign(**metadata))
            coeff_frames.append(model.get_coefficients().assign(**metadata))
    if not coeff_frames:
        # pd.concat([]) would raise an opaque "No objects to concatenate";
        # fail with the same exception type but an actionable message.
        raise ValueError(
            "No models were built for any predictor/study combination; "
            "nothing to write."
        )
    # Not named `input`, which would shadow the builtin.
    input_df = pd.concat(input_frames)
    input_df.to_csv("input.tsv", sep="\t", index=False)
    coeffs = pd.concat(coeff_frames)
    coeffs.to_csv("fits.tsv", sep="\t", index=False)
    summary = summarize_output(coeffs)
    # The summary keeps its index because the group keys live there.
    summary.to_csv("fits_summary.tsv", sep="\t")
if __name__ == "__main__":
    # Imported here because argparse is only needed when run as a script.
    import argparse

    parser = argparse.ArgumentParser(
        description=(
            "Fit models and write input.tsv, fits.tsv and fits_summary.tsv."
        )
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=1000,
        help="value passed to start() as num_samples (default: %(default)s)",
    )
    parser.add_argument(
        "--no-plot",
        dest="plot",
        action="store_false",
        help="do not write figures to the fig directory",
    )
    args = parser.parse_args()
    # With no arguments this is the same call as the previous hard-coded
    # start(num_samples=1000, plot=True).
    start(num_samples=args.num_samples, plot=args.plot)