Skip to content

Commit

Permalink
add site/src/figs/(largest-fp-diff-each-error-models|largest-each-errors-fp-diff-models).svelte
Browse files Browse the repository at this point in the history

shown on /models/tmi page (ex /models/per-element)
generated by scripts/difficult_structures.py

add col site_stats_fingerprint_init_final_norm_diff to data/wbm/2022-10-19-wbm-summary.csv

sort EACH scatter plot facets and legend by MAE
  • Loading branch information
janosh committed Jun 20, 2023
1 parent 7946b5e commit 4b6e83a
Show file tree
Hide file tree
Showing 10 changed files with 220 additions and 87 deletions.
2 changes: 1 addition & 1 deletion .gitattributes
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# exclude generated plot files when calculating repo language statistics on GitHub
*/figs/* linguist-generated
**/figs/* linguist-generated
data/**/*.svelte linguist-generated
4 changes: 2 additions & 2 deletions data/mp/get_mp_energies.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pandas as pd
from aviary.wren.utils import get_aflow_label_from_spglib
from mp_api.client import MPRester
from pymatviz.utils import annotate_mae_r2
from pymatviz.utils import annotate_metrics
from tqdm import tqdm

from matbench_discovery import today
Expand Down Expand Up @@ -71,7 +71,7 @@
title=f"{today} - {len(df):,} MP entries",
)

annotate_mae_r2(df.formation_energy_per_atom, df.decomposition_enthalpy)
annotate_metrics(df.formation_energy_per_atom, df.decomposition_enthalpy)
# result on 2023-01-10: plots match. no correlation between formation energy and
# decomposition enthalpy. R^2 = -1.571, MAE = 1.604
# ax.figure.savefig(f"{module_dir}/mp-decomp-enth-vs-e-form.webp", dpi=300)
Expand Down
13 changes: 13 additions & 0 deletions data/wbm/fetch_process_wbm_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import warnings
from glob import glob

import numpy as np
import pandas as pd
from aviary.wren.utils import get_aflow_label_from_spglib
from pymatgen.analysis.phase_diagram import PatchedPhaseDiagram
Expand Down Expand Up @@ -597,6 +598,18 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
assert df_summary[wyckoff_col].isna().sum() == 0


# %% site-stats.json.gz was generated by scripts/compute_struct_fingerprints.py
# Load the site-stats fingerprints and index the frame by material_id.
df_fp = pd.read_json(f"{module_dir}/site-stats.json.gz").set_index("material_id")
init_fp_col = "initial_site_stats_fingerprint"
final_fp_col = "final_site_stats_fingerprint"
fp_diff_col = "site_stats_fingerprint_init_final_norm_diff"
# Per row: turn both fingerprints into numpy arrays, subtract initial from final
# and collapse the difference to one scalar with np.linalg.norm (default arguments).
# Presumably each fingerprint cell holds a flat list of floats -- TODO confirm.
df_fp[fp_diff_col] = (
df_fp[final_fp_col].map(np.array) - df_fp[init_fp_col].map(np.array)
).map(np.linalg.norm)

# Histogram of the fingerprint diffs (100 bins, plotly backend) for visual inspection.
df_fp[fp_diff_col].hist(bins=100, backend="plotly")
# NOTE(review): fp_diff_col is only assigned on df_fp in the lines visible here;
# confirm it is copied onto df_summary before the CSV write in the next cell.


# %% write final summary data to disk (yeah!)
# Values are rounded to 6 decimals before the CSV is written.
df_summary.round(6).to_csv(f"{module_dir}/{today}-wbm-summary.csv")

Expand Down
28 changes: 27 additions & 1 deletion scripts/compute_struct_fingerprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# %%
import os
import warnings
from glob import glob

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -46,6 +47,7 @@
account="LEE-SL3-CPU",
time="6:0:0",
array=f"1-{slurm_array_task_count}",
slurm_flags=("--mem", "30G"),
)


Expand Down Expand Up @@ -95,4 +97,28 @@
except Exception as exc:
print(f"{fp_col} for {row.Index} failed: {exc}")

df_in.filter(like="site_stats_fingerprint").to_json(out_path)
df_in.filter(like="site_stats_fingerprint").reset_index().to_json(out_path)


# %%
# When SLURM_JOB_ID is present in the environment, treat this run as a single
# batch task: report the file it wrote and exit with status 0 so the
# aggregation cell that follows is not executed by the task.
running_as_slurm_job = os.getenv("SLURM_JOB_ID")
if running_as_slurm_job:
    print(f"Job wrote {out_path=} and finished at {timestamp}")
    raise SystemExit(0)


# %%
# Gather the per-task output files and check that every array task produced one.
out_files = glob(f"{out_dir}/site-stats-*.json.gz")

# The task index is the integer between the last "-" and the first "." after it,
# e.g. ".../site-stats-12.json.gz" -> 12.
found_idx = [int(name.split("-")[-1].split(".")[0]) for name in out_files]
print(f"Found {len(out_files)=:,}")
# Task ids run from 1 to slurm_array_task_count inclusive; list those with no file.
missing_files = sorted(set(range(1, slurm_array_task_count + 1)) - set(found_idx))
if missing_files:
    print(f"{len(missing_files)=}: {missing_files}")

# Stack all per-task frames into one frame.
df_out = pd.concat(pd.read_json(out_file) for out_file in tqdm(out_files))

df_out.index.name = "material_id"

# reset_index() moves the (now named) index into a regular material_id column
# before the combined file is written.
# NOTE(review): the per-task writer also calls reset_index() before to_json; if
# those files already contain a material_id column, this reset_index() would
# raise on the duplicate column name -- confirm against freshly written files.
df_out.reset_index().to_json(f"{out_dir}/site-stats.json.gz")
Loading

0 comments on commit 4b6e83a

Please sign in to comment.