Skip to content

Commit

Permalink
Fix test scripts raising wandb.Table AssertionError: columns argu…
Browse files Browse the repository at this point in the history
…ment expects list of strings or ints (#82)

* fix model test scripts raising wandb.Table AssertionError: columns argument expects list of strings or ints

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
janosh and pre-commit-ci[bot] authored Jan 30, 2024
1 parent 3118330 commit 3f6c798
Show file tree
Hide file tree
Showing 7 changed files with 10 additions and 14 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,3 @@ site/src/routes/api/*.md

# temporary ignore rules
2023-02-05-ens=10-perturb=5
models/m3gnet/2023-05-26-*
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.13
rev: v0.1.15
hooks:
- id: ruff
args: [--fix]
Expand Down Expand Up @@ -56,7 +56,7 @@ repos:
exclude: ^(site/src/figs/.+\.svelte|data/wbm/20.+\..+|site/src/routes/.+\.(yaml|json)|changelog.md)$

- repo: https://github.com/pre-commit/mirrors-eslint
rev: v9.0.0-alpha.0
rev: v9.0.0-alpha.2
hooks:
- id: eslint
types: [file]
Expand Down
3 changes: 2 additions & 1 deletion matbench_discovery/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,4 +267,5 @@ def _on_not_found(self, key: str, path: str) -> None: # type: ignore[override]


df_wbm = load("wbm_summary")
df_wbm[Key.mat_id] = df_wbm.index
# str() around Key.mat_id added for https://github.com/janosh/matbench-discovery/issues/81
df_wbm[str(Key.mat_id)] = df_wbm.index
4 changes: 1 addition & 3 deletions models/alignn_ff/test_alignn_ff.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,7 @@


# %%
df_wbm = df_wbm.dropna()

table = wandb.Table(dataframe=df_wbm[[Key.e_form, pred_col]].reset_index())
table = wandb.Table(dataframe=df_wbm[[Key.e_form, pred_col]].reset_index().dropna())

MAE = (df_wbm[Key.e_form] - df_wbm[pred_col]).abs().mean()
R2 = r2_score(df_wbm[Key.e_form], df_wbm[pred_col])
Expand Down
2 changes: 1 addition & 1 deletion models/chgnet/test_chgnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@
# %%
df_wbm[e_pred_col] = df_out[e_pred_col]
table = wandb.Table(
dataframe=df_wbm.dropna()[[Key.dft_energy, e_pred_col, Key.formula]].reset_index()
dataframe=df_wbm[[Key.dft_energy, e_pred_col, Key.formula]].reset_index().dropna()
)

title = f"CHGNet {task_type} ({len(df_out):,})"
Expand Down
8 changes: 3 additions & 5 deletions models/mace/test_mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,7 @@
device = "cuda" if torch.cuda.is_available() else "cpu"
# whether to record intermediate structures into pymatgen Trajectory
record_traj = False # has no effect if relax_cell is False
model_name = [
"2023-10-29-mace-16M-pbenner-mptrj-no-conditional-loss",
"https://tinyurl.com/y7uhwpje",
][-1]
model_name = "https://tinyurl.com/y7uhwpje"
ase_filter: Literal["frechet", "exp"] = "frechet"

slurm_vars = slurm_submit(
Expand Down Expand Up @@ -163,8 +160,9 @@

# %%
df_wbm[e_pred_col] = df_out[e_pred_col]

table = wandb.Table(
dataframe=df_wbm.dropna()[[Key.dft_energy, e_pred_col, Key.formula]].reset_index()
dataframe=df_wbm[[Key.dft_energy, e_pred_col, Key.formula]].reset_index().dropna()
)

title = f"MACE {task_type} ({len(df_out):,})"
Expand Down
2 changes: 1 addition & 1 deletion scripts/model_figs/make_metrics_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@
styler = (
df_filtered.style.format(
# render integers without decimal places
{key: "{:,.0f}" for key in "TP FN FP TN".split()},
dict.fromkeys("TP FN FP TN".split(), "{:,.0f}"),
precision=2, # render floats with 2 decimals
na_rep="", # render NaNs as empty string
)
Expand Down

0 comments on commit 3f6c798

Please # to comment.