Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add metrics analysis excluding WBM materials with duplicate/MP-matching structure prototype #75

Merged
merged 9 commits into from
Jan 25, 2024
2 changes: 1 addition & 1 deletion .github/workflows/slow-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: 3.9
python-version: 3.11

- name: Install dependencies
run: pip install -e .[test]
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test-scripts.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: 3.9
python-version: 3.11

- name: Install package and dependencies
run: pip install -e .[fetch-data]
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@ jobs:
uses: janosh/workflows/.github/workflows/pytest-release.yml@main
with:
os: ${{ matrix.os }}
python-version: 3.11
12 changes: 2 additions & 10 deletions data/figshare/1.0.0.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,6 @@
"https://figshare.com/ndownloader/files/41233560",
"2023-06-02-pbenner-best-alignn-model.pth.zip"
],
"mace_checkpoint_1": [
"https://figshare.com/ndownloader/files/42374049",
"2023-08-14-mace-yuan-trained-mptrj-04.model"
],
"mace_checkpoint_2": [
"https://figshare.com/ndownloader/files/43117273",
"2023-10-29-mace-16M-pbenner-mptrj-no-conditional-loss.model"
],
"mp_computed_structure_entries": [
"https://figshare.com/ndownloader/files/40344436",
"2023-02-07-mp-computed-structure-entries.json.gz"
Expand Down Expand Up @@ -41,8 +33,8 @@
"2022-10-19-wbm-computed-structure-entries+init-structs.json.bz2"
],
"wbm_summary": [
"https://figshare.com/ndownloader/files/41296866",
"2022-10-19-wbm-summary.csv.gz"
"https://figshare.com/ndownloader/files/44225498",
"2023-12-13-wbm-summary.csv.gz"
],
"mp_trj_extxyz_by_yuan": [
"https://figshare.com/ndownloader/files/43302033",
Expand Down
Binary file not shown.
81 changes: 59 additions & 22 deletions data/wbm/compile_wbm_test_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from pymatviz.io import save_fig
from tqdm import tqdm

from matbench_discovery import PDF_FIGS, SITE_FIGS, Key, today
from matbench_discovery import PDF_FIGS, SITE_FIGS, WBM_DIR, Key, today
from matbench_discovery.data import DATA_FILES
from matbench_discovery.energy import get_e_form_per_atom

Expand All @@ -38,9 +38,6 @@
"""


module_dir = os.path.dirname(__file__)


# %% links to google drive files received via email from 1st author Hai-Chen Wang
# on 2021-06-15 containing initial and relaxed structures
google_drive_ids = {
Expand All @@ -53,10 +50,10 @@


# %%
os.makedirs(f"{module_dir}/raw", exist_ok=True)
os.makedirs(f"{WBM_DIR}/raw", exist_ok=True)

for step, file_id in google_drive_ids.items():
file_path = f"{module_dir}/raw/wbm-structures-step-{step}.json.bz2"
file_path = f"{WBM_DIR}/raw/wbm-structures-step-{step}.json.bz2"

if os.path.exists(file_path):
print(f"{file_path} already exists, skipping")
Expand All @@ -67,7 +64,7 @@


# %%
summary_path = f"{module_dir}/raw/wbm-summary.txt"
summary_path = f"{WBM_DIR}/raw/wbm-summary.txt"

if not os.path.exists(summary_path):
summary_id_file = "1639IFUG7poaDE2uB6aISUOi65ooBwCIg"
Expand All @@ -76,7 +73,7 @@


# %%
json_paths = sorted(glob(f"{module_dir}/raw/wbm-structures-step-*.json.bz2"))
json_paths = sorted(glob(f"{WBM_DIR}/raw/wbm-structures-step-*.json.bz2"))
step_lens = (61848, 52800, 79205, 40328, 23308)
# step 3 has 79,211 initial structures but only 79,205 ComputedStructureEntries
# i.e. 6 extra structures which have missing energy, volume, etc. in the summary file
Expand Down Expand Up @@ -177,7 +174,7 @@ def increment_wbm_material_id(wbm_id: str) -> str:
# "summary.txt.bz2",
*(f"step_{step}.json.bz2" for step in range(1, 6)),
):
file_path = f"{module_dir}/raw/wbm-cse-{filename.lower().replace('_', '-')}"
file_path = f"{WBM_DIR}/raw/wbm-cse-{filename.lower().replace('_', '-')}"
if os.path.exists(file_path):
print(f"{file_path} already exists, skipping")
continue
Expand All @@ -191,7 +188,7 @@ def increment_wbm_material_id(wbm_id: str) -> str:


# %%
cse_step_paths = sorted(glob(f"{module_dir}/raw/wbm-cse-step-*.json.bz2"))
cse_step_paths = sorted(glob(f"{WBM_DIR}/raw/wbm-cse-step-*.json.bz2"))
assert len(cse_step_paths) == 5

"""
Expand Down Expand Up @@ -295,14 +292,14 @@ def increment_wbm_material_id(wbm_id: str) -> str:
"vol": "volume",
"e": Key.dft_energy,
"e_form": Key.e_form_wbm,
"e_hull": "e_above_hull_wbm",
"e_hull": Key.each_wbm,
"gap": Key.bandgap_pbe,
"id": Key.mat_id,
}
# WBM summary was shared twice, once on google drive, once on materials cloud
# download both and check for consistency
df_summary = pd.read_csv(
f"{module_dir}/raw/wbm-summary.txt", sep="\t", names=col_map.values()
f"{WBM_DIR}/raw/wbm-summary.txt", sep="\t", names=col_map.values()
).set_index(Key.mat_id)

df_summary_bz2 = pd.read_csv(
Expand Down Expand Up @@ -398,7 +395,7 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
),
):
cols = ["formula_from_cse", *cols] # type: ignore[list-item]
df_wbm[cols].reset_index().to_json(f"{module_dir}/{today}-wbm-{fname}.json.bz2")
df_wbm[cols].reset_index().to_json(f"{WBM_DIR}/{today}-wbm-{fname}.json.bz2")


# %%
Expand Down Expand Up @@ -589,18 +586,34 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
try:
from aviary.wren.utils import get_aflow_label_from_spglib

if Key.wyckoff not in df_wbm:
df_summary[Key.wyckoff] = None
# add Aflow-style Wyckoff labels for initial and relaxed structures
for key in (Key.init_wyckoff, Key.wyckoff):
if key not in df_wbm:
df_summary[key] = None

for idx, struct in tqdm(df_wbm[Key.init_struct].items(), total=len(df_wbm)):
if not pd.isna(df_summary.loc[idx, Key.wyckoff]):
# from initial structures
for idx in tqdm(df_wbm.index):
if not pd.isna(df_summary.loc[idx, Key.init_wyckoff]):
continue # Aflow label already computed
try:
struct = Structure.from_dict(struct)
struct = Structure.from_dict(df_wbm.loc[idx, Key.init_struct])
df_summary.loc[idx, Key.init_wyckoff] = get_aflow_label_from_spglib(struct)
except Exception as exc:
print(f"{idx=} {exc=}")

# from relaxed structures
for idx in tqdm(df_wbm.index):
if not pd.isna(df_summary.loc[idx, Key.wyckoff]):
continue

try:
cse = df_wbm.loc[idx, Key.cse]
struct = Structure.from_dict(cse["structure"])
df_summary.loc[idx, Key.wyckoff] = get_aflow_label_from_spglib(struct)
except Exception as exc:
print(f"{idx=} {exc=}")

assert df_summary[Key.init_wyckoff].isna().sum() == 0
assert df_summary[Key.wyckoff].isna().sum() == 0
except ImportError:
print("aviary not installed, skipping Wyckoff label generation")
Expand All @@ -609,7 +622,7 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:


# %%
fingerprints_path = f"{module_dir}/site-stats.json.gz"
fingerprints_path = f"{WBM_DIR}/site-stats.json.gz"
suggest = "not found, run scripts/compute_struct_fingerprints.py to generate"
fp_diff_col = "site_stats_fingerprint_init_final_norm_diff"
try:
Expand All @@ -621,16 +634,40 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
print(f"{fingerprints_path=} does not contain {fp_diff_col=}")


# %% mark WBM materials with matching prototype in MP or duplicate prototypes
# in WBM (keeping only the lowest energy one)
df_mp = pd.read_csv(DATA_FILES.mp_energies, index_col=0)

# mask WBM materials with matching prototype in MP
mask_proto_in_mp = df_summary[Key.wyckoff].isin(df_mp[Key.wyckoff])
# mask duplicate prototypes in WBM (keeping the lowest energy one)
mask_dupe_protos = df_summary.sort_values(by=[Key.wyckoff, Key.each_wbm]).duplicated(
subset=Key.wyckoff, keep="first"
)
assert sum(mask_proto_in_mp) == 11_175, f"{sum(mask_proto_in_mp)=:_}"
assert sum(mask_dupe_protos) == 32_784, f"{sum(mask_dupe_protos)=:_}"

df_summary[Key.uniq_proto] = ~(mask_proto_in_mp | mask_dupe_protos)
assert dict(df_summary[Key.uniq_proto].value_counts()) == {True: 215_488, False: 41_475}
assert list(df_summary.query(f"~{Key.uniq_proto}").head(5).index) == [
"wbm-1-7",
"wbm-1-8",
"wbm-1-15",
"wbm-1-20",
"wbm-1-33",
]


# %% write final summary data to disk (yeah!)
df_summary.round(6).to_csv(f"{module_dir}/{today}-wbm-summary.csv")
df_summary.round(6).to_csv(f"{WBM_DIR}/{today}-wbm-summary.csv.gz")


# %% only here to load data for later inspection
if False:
wbm_summary_path = f"{module_dir}/2022-10-19-wbm-summary.csv.gz"
wbm_summary_path = f"{WBM_DIR}/2022-10-19-wbm-summary.csv.gz"
df_summary = pd.read_csv(wbm_summary_path).set_index(Key.mat_id)
df_wbm = pd.read_json(
f"{module_dir}/2022-10-19-wbm-computed-structure-entries+init-structs.json.bz2"
f"{WBM_DIR}/2022-10-19-wbm-computed-structure-entries+init-structs.json.bz2"
).set_index(Key.mat_id)

df_wbm["cse"] = [
Expand Down
12 changes: 11 additions & 1 deletion data/wbm/eda_wbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,16 @@
)


# %% print prevalence of stable structures in full WBM and uniq-prototypes only
for df, label in (
(df_wbm, "full WBM"),
(df_wbm.query(Key.uniq_proto), "WBM unique prototypes"),
):
n_stable = sum(df[Key.each_true] <= STABILITY_THRESHOLD)
stable_rate = n_stable / len(df)
print(f"{label}: {stable_rate=:.1%} ({n_stable:,} out of {len(df):,})")


# %%
for dataset, count_mode, elem_counts in all_counts:
filename = f"{dataset}-element-counts-by-{count_mode}"
Expand Down Expand Up @@ -303,7 +313,7 @@


# %%
df_wbm[Key.spacegroup] = df_wbm[Key.wyckoff].str.split("_").str[2].astype(int)
df_wbm[Key.spacegroup] = df_wbm[Key.init_wyckoff].str.split("_").str[2].astype(int)
df_mp[Key.spacegroup] = df_mp[Key.wyckoff].str.split("_").str[2].astype(int)


Expand Down
2 changes: 1 addition & 1 deletion data/wbm/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ The full set of processing steps used to curate the WBM test set from the raw da
- remove 6 pathological structures (with 0 volume)
- remove formation energy outliers below -5 and above 5 eV/atom (502 and 22 crystals respectively out of 257,487 total, including an anomaly of 500 structures at exactly -10 eV/atom)

<caption>WBM Formation energy distribution. 524 materials outside dashed lines were discarded.<br />(zoom out on this plot to see discarded samples)</caption>
<caption style="margin: 1em;">WBM Formation energy distribution. 524 materials outside dashed lines were discarded.</caption>
<slot name="hist-e-form-per-atom">
<img src="./figs/hist-wbm-e-form-per-atom.svg" alt="WBM formation energy histogram indicating outlier cutoffs">
</slot>
Expand Down
13 changes: 7 additions & 6 deletions matbench_discovery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
import warnings
from datetime import datetime
from enum import Enum, unique
from enum import StrEnum, unique
from importlib.metadata import Distribution

import matplotlib.pyplot as plt
Expand Down Expand Up @@ -54,10 +54,6 @@
warnings.filterwarnings(action="ignore", category=UserWarning, module="pymatgen")


class StrEnum(str, Enum):
"""Enum whose members are also (and must be) strings."""


@unique
class Key(StrEnum):
"""Keys used to access dataframes columns."""
Expand All @@ -72,8 +68,10 @@ class Key(StrEnum):
e_form_pred = "e_form_per_atom_pred"
e_form_raw = "e_form_per_atom_uncorrected"
e_form_wbm = "e_form_per_atom_wbm"
each = "energy_above_hull" # as returned by MP API
each_pred = "e_above_hull_pred"
each_true = "e_above_hull_mp2020_corrected_ppd_mp"
each_wbm = "e_above_hull_wbm"
final_struct = "relaxed_structure"
forces = "forces"
form_energy = "formation_energy_per_atom"
Expand All @@ -91,8 +89,11 @@ class Key(StrEnum):
stress_trace = "stress_trace"
struct = "structure"
task_id = "task_id"
# lowest WBM structures for a given prototype that isn't already in MP
uniq_proto = "unique_prototype"
volume = "volume"
wyckoff = "wyckoff_spglib"
wyckoff = "wyckoff_spglib" # relaxed structure Aflow label
init_wyckoff = "wyckoff_spglib_initial_structure" # initial structure Aflow label


@unique
Expand Down
Loading
Loading