Skip to content

Commit

Permalink
model_figs/__init__.py use init_globals to run several scripts with d…
Browse files Browse the repository at this point in the history
…ifferent parameter values
  • Loading branch information
janosh committed Dec 24, 2023
1 parent 8dd620f commit 5d93908
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 9 deletions.
23 changes: 17 additions & 6 deletions scripts/model_figs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,34 @@

import plotly.graph_objects as go
from dash import Dash
from tqdm import tqdm

__author__ = "Janosh Riebesell"
__date__ = "2023-07-14"

module_dir = os.path.dirname(__file__)

# monkey patch go.Figure.show() and Dash.run() to prevent them opening a browser
# monkey patch go.Figure.show() and Dash.run() to prevent them from opening browser
go.Figure.show = lambda self, *args, **kwargs: None
Dash.run = lambda self, *args, **kwargs: None

# subtract __file__ to avoid this file calling itself
scripts = set(glob(f"{module_dir}/*.py")) - {__file__}


# %%
for file in glob(f"{module_dir}/*.py"):
if file == __file__: # skip this file
continue
print(f"Running {file.split(os.path.sep)[-1]}...")
for file in (pbar := tqdm(scripts)):
pbar.set_description(file)
try:
runpy.run_path(file)
if file.endswith("parity_energy_models.py"):
for which_energy in ("each", "e-form"):
runpy.run_path(file, init_globals={"which_energy": which_energy})
elif file.endswith("cumulative_metrics.py"):
for metrics in (("MAE",), ("Precision", "Recall")):
runpy.run_path(file, init_globals={"metrics": metrics})
elif file.endswith("rolling_mae_vs_hull_dist_wbm_batches.py"):
runpy.run_path(file, init_globals={"models": ("CHGNet", "MACE")})
else:
runpy.run_path(file)
except Exception as exc:
print(f"{file!r} failed: {exc}")
4 changes: 2 additions & 2 deletions scripts/model_figs/cumulative_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@


# %%
# metrics = ("Precision", "Recall")
metrics = ("MAE",)
metrics: tuple[str, ...] = globals().get("metrics", ("Precision", "Recall"))
# metrics = ("MAE",)
range_y = {
("MAE",): (0, 0.7),
("Precision", "Recall"): (0, 1),
Expand Down
2 changes: 1 addition & 1 deletion scripts/model_figs/parity_energy_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
legend = dict(x=1, y=0, xanchor="right", yanchor="bottom", title=None)

# toggle between formation energy and energy above convex hull
which_energy: Literal["e-form", "each"] = "each"
which_energy: Literal["e-form", "each"] = globals().get("which_energy", "each")
if which_energy == "each":
e_pred_col = each_pred_col
e_true_col = each_true_col
Expand Down

0 comments on commit 5d93908

Please # to comment.