Skip to content

Commit

Permalink
add columns wyckoff_spglib and spacegroup_symbol to data/mp/2023-01-1…
Browse files Browse the repository at this point in the history
…0-mp-energies.csv

add spacegroup sunburst plots for wbm and mp to /about-the-data page and to Wrenformer failure cases analysis

add df_to_svelte_table() in plots.py
add prototype prevalence analysis to wrenformer failure cases

add dummy row to metrics table (combining metrics from a dummy classifier and a dummy regressor) and move R2 to last col, add dotted line between reg and clf metrics

add rolling MAE vs hull dist by WBM batches plots for all models

add favicon-black.pdf

change interatomic potential abbreviation IAP -> UIP to highlight the universal aspect (full periodic table)
  • Loading branch information
janosh committed Jun 20, 2023
1 parent 457abf0 commit 01e88f0
Show file tree
Hide file tree
Showing 39 changed files with 921 additions and 321 deletions.
2 changes: 1 addition & 1 deletion citation.cff
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,4 @@ type: software
url: https://github.com/janosh/matbench-discovery
doi: TODO
version: 1.0.0 # replace with whatever Matbench Discovery version you use
date-released: TODO
date-released: date TBD
27 changes: 23 additions & 4 deletions data/mp/get_mp_energies.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pandas as pd
from aviary.wren.utils import get_aflow_label_from_spglib
from mp_api.client import MPRester
from pymatgen.core import Structure
from pymatviz.utils import annotate_metrics
from tqdm import tqdm

Expand Down Expand Up @@ -33,6 +34,7 @@
"energy_above_hull",
"decomposition_enthalpy",
"energy_type",
"symmetry",
}

with MPRester(use_document_model=False) as mpr:
Expand All @@ -47,17 +49,34 @@
# %%
df = pd.DataFrame(docs).set_index("material_id")

df_spg = pd.json_normalize(df.pop("symmetry"))[["number", "symbol"]]
df["spacegroup_symbol"] = df_spg.symbol.values

df.energy_type.value_counts().plot.pie(backend="plotly", autopct="%1.1f%%")
# GGA: 72.2%, GGA+U: 27.8%


# %%
df["spacegroup_number"] = [x["number"] for x in df.pop("symmetry")]

df["wyckoff_spglib"] = [get_aflow_label_from_spglib(x) for x in tqdm(df.structure)]
df_cse = pd.read_json(DATA_FILES.mp_computed_structure_entries).set_index("material_id")

df_cse["structure"] = [
Structure.from_dict(cse["structure"]) for cse in tqdm(df_cse.entry)
]
wyk_col = "wyckoff_spglib"
df_cse[wyk_col] = [
get_aflow_label_from_spglib(struct, errors="ignore")
for struct in tqdm(df_cse.structure)
]
# make sure symmetry detection succeeded for all structures
assert df_cse[wyk_col].str.startswith("invalid").sum() == 0
df[wyk_col] = df_cse[wyk_col]

spg_nums = df[wyk_col].str.split("_").str[2].astype(int)
# make sure all our spacegroup numbers match MP's
assert (spg_nums.sort_index() == df_spg["number"].sort_index()).all()

df.to_csv(DATA_FILES.mp_energies)
# df = pd.read_csv(DATA_FILES.mp_energies, na_filter=False)
# df = pd.read_csv(DATA_FILES.mp_energies, na_filter=False).set_index("material_id")


# %% reproduce fig. 1b from https://arxiv.org/abs/2001.10591 (as data consistency check)
Expand Down
52 changes: 41 additions & 11 deletions data/wbm/eda.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
import pandas as pd
import plotly.express as px
from pymatgen.core import Composition
from pymatviz import count_elements, ptable_heatmap_plotly
from pymatviz import (
count_elements,
ptable_heatmap_plotly,
spacegroup_sunburst,
)
from pymatviz.utils import save_fig

from matbench_discovery import FIGS, ROOT, today
Expand All @@ -31,7 +35,15 @@

# wbm_elem_counts.to_json(f"{about_data_page}/wbm-element-counts.json")

# export element counts by WBM step to JSON

# %% load MP training set
df_mp = pd.read_csv(DATA_FILES.mp_energies, na_filter=False)
mp_elem_counts = count_elements(df_mp.formula_pretty).astype(int)

# mp_elem_counts.to_json(f"{about_data_page}/mp-element-counts.json")


# %% export element counts by WBM step to JSON
df_wbm["step"] = df_wbm.index.str.split("-").str[1].astype(int)
assert df_wbm.step.between(1, 5).all()
for batch in range(1, 6):
Expand All @@ -43,8 +55,8 @@
comp_col = "composition"
df_wbm[comp_col] = df_wbm.formula.map(Composition)

for arity, df in df_wbm.groupby(df_wbm[comp_col].map(len)):
count_elements(df.formula).to_json(
for arity, df_mp in df_wbm.groupby(df_wbm[comp_col].map(len)):
count_elements(df_mp.formula).to_json(
f"{about_data_page}/wbm-element-counts-{arity=}.json"
)

Expand Down Expand Up @@ -73,13 +85,6 @@
# save_fig(wbm_fig, f"{FIGS}/wbm-elements.svelte")


# %% load MP training set
df = pd.read_csv(DATA_FILES.mp_energies, na_filter=False)
mp_elem_counts = count_elements(df.formula_pretty).astype(int)

# mp_elem_counts.to_json(f"{about_data_page}/mp-element-counts.json")


# %%
mp_fig = ptable_heatmap_plotly(
mp_elem_counts[mp_elem_counts > 1],
Expand Down Expand Up @@ -233,3 +238,28 @@
f"{color_col}: %{{customdata[3]:.2f}}<br>"
)
fig.show()


# %%
wyk_col, spg_col = "wyckoff_spglib", "spacegroup"
df_wbm[spg_col] = df_wbm[wyk_col].str.split("_").str[2].astype(int)
df_mp[spg_col] = df_mp[wyk_col].str.split("_").str[2].astype(int)


# %%
fig = spacegroup_sunburst(df_wbm[spg_col], width=350, height=350, show_counts="percent")
fig.layout.title.update(text="WBM Spacegroup Sunburst", x=0.5, font_size=14)
fig.show()
save_fig(fig, f"{FIGS}/spacegroup-sunburst-wbm.svelte")


# %%
fig = spacegroup_sunburst(df_mp[spg_col], width=350, height=350, show_counts="percent")
fig.layout.title.update(text="MP Spacegroup Sunburst", x=0.5, font_size=14)
fig.show()
save_fig(fig, f"{FIGS}/spacegroup-sunburst-mp.svelte")
# would be good to have consistent order of crystal systems between sunbursts but not
# controllable yet
# https://github.com/plotly/plotly.py/issues/4115
# https://github.com/plotly/plotly.js/issues/5341
# https://github.com/plotly/plotly.js/issues/4728
28 changes: 17 additions & 11 deletions data/wbm/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,20 @@ materialscloud:2021.68 includes a readme file with a description of the dataset,

[wbm paper]: https://nature.com/articles/s41524-020-00481-6

## 🎯 &thinsp; Target Distribution

The WBM test set has an energy above the MP convex hull distribution with **mean ± std = 0.02 ± 0.25 eV/atom**.

The dummy MAE of always predicting the test set mean is **0.12 eV/atom**.

The number of stable materials (according to the MP convex hull which is spanned by the training data the models have access to) is **97k** out of **257k**, resulting in a dummy stability hit rate of **37%**.

> Note: [According to the authors](https://www.nature.com/articles/s41524-020-00481-6#Sec2), the stability rate w.r.t. to the more complete hull constructed from the combined train and test set (MP + WBM) for the first 3 rounds of elemental substitution is 18,479 out of 189,981 crystals ($\approx$ 9.7%).
<slot name="wbm-each-hist">
<img src="./figs/wbm-each-hist.svg" alt="WBM energy above MP convex hull distribution">
</slot>

## 🧪 &thinsp; Chemical Diversity

The WBM test set and even more so the MP training set are heavily oxide dominated. The WBM test set is about 75% larger than the MP training set and also more chemically diverse, containing a higher fraction of transition metals, post-transition metals and metalloids. Our goal in picking such a large diverse test set is future-proofing. Ideally, this data will provide a challenging materials discovery test bed even for large foundational ML models in the future.
Expand All @@ -88,16 +102,8 @@ Element counts for MP training set consisting of 146,323 `ComputedStructureEntri
<img src="./figs/mp-elements.svg" alt="Periodic table log heatmap of MP elements">
</slot>

## 🎯 &thinsp; Target Distribution

The WBM test set has an energy above the MP convex hull distribution with **mean ± std = 0.02 ± 0.25 eV/atom**.

The dummy MAE of always predicting the test set mean is **0.17 eV/atom**.
## 📊 &thinsp; Symmetry Statistics

The number of stable materials (according to the MP convex hull which is spanned by the training data the models have access to) is **97k** out of **257k**, resulting in a dummy stability hit rate of **37%**.
With one exception, MP and WBM have diverse representation across all 7 crystal systems. In MP, monoclinic (23%) and orthorhombic (21%) are most prevalent. In WBM, orthorhombic and tetragonal each make up 20%. Triclinic crystals are notably almost absent from WBM at just 1% prevalence, but well represented in MP (15%). Combined with the higher share of cubic structures in WBM (19% vs 14%), WBM structures have overall higher symmetry. This should benefit a model like Wrenformer reliant on symmetries to encode coarse-grained structural features. See [SI](/si#spacegroup-prevalence-in-wrenformer-failure-cases) for a failure case of this featurization.

> Note: [According to the authors](https://www.nature.com/articles/s41524-020-00481-6#Sec2), the stability rate w.r.t. to the more complete hull constructed from the combined train and test set (MP + WBM) for the first 3 rounds of elemental substitution is 18,479 out of 189,981 crystals ($\approx$ 9.7%).
<slot name="wbm-each-hist">
<img src="./figs/wbm-each-hist.svg" alt="WBM energy above MP convex hull distribution">
</slot>
<slot name="spacegroup-sunbursts" />
2 changes: 1 addition & 1 deletion matbench_discovery/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,12 @@ def stable_metrics(

return dict(
F1=2 * (precision * recall) / (precision + recall),
R2=r2_score(each_true, each_pred),
DAF=precision / prevalence,
Precision=precision,
Recall=recall,
**dict(TPR=TPR, FPR=FPR, TNR=TNR, FNR=FNR),
Accuracy=(n_true_pos + n_true_neg) / len(each_true),
MAE=np.abs(each_true - each_pred).mean(),
RMSE=((each_true - each_pred) ** 2).mean() ** 0.5,
R2=r2_score(each_true, each_pred),
)
36 changes: 36 additions & 0 deletions matbench_discovery/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import math
from collections import defaultdict
from collections.abc import Sequence
from pathlib import Path
from typing import Any, Literal

import matplotlib.pyplot as plt
Expand All @@ -17,6 +18,7 @@
import scipy.stats
import wandb
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
from pandas.io.formats.style import Styler
from tqdm import tqdm

from matbench_discovery.metrics import classify_stable
Expand Down Expand Up @@ -805,3 +807,37 @@ def wandb_scatter(table: wandb.Table, fields: dict[str, str], **kwargs: Any) ->
)

wandb.log({"true_pred_scatter": scatter_plot})


def df_to_svelte_table(
styler: Styler,
file_path: str | Path,
inline_props: str = "",
styles: str = "",
**kwargs: Any,
) -> None:
"""Convert a pandas Styler to a svelte table.
Args:
styler (Styler): Styler object to convert.
file_path (str): Path to the file to write the svelte table to.
inline_props (str): Inline props to pass to the table element.
styles (str): CSS rules to add to the table styles.
**kwargs: Keyword arguments passed to Styler.to_html().
"""
# insert svelte {...props} forwarding to the table element
script = f"""
<script lang="ts">
import {{ sortable }} from 'svelte-zoo/actions'
</script>
<table use:sortable {inline_props} {{...$$props}}
"""

html_table = (
styler.to_html(**kwargs)
.replace("<table", script)
.replace("</style>", f"{styles}</style>")
)
with open(file_path, "w") as file:
file.write(html_table)
6 changes: 2 additions & 4 deletions matbench_discovery/preds.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,8 @@ def load_df_wbm_with_preds(
df_metrics = pd.DataFrame()
df_metrics.index.name = "model"
for model in PRED_FILES:
df_metrics[model] = stable_metrics(
df_preds[each_true_col],
df_preds[each_true_col] + df_preds[model] - df_preds[e_form_col],
)
each_pred = df_preds[each_true_col] + df_preds[model] - df_preds[e_form_col]
df_metrics[model] = stable_metrics(df_preds[each_true_col], each_pred)

# pick F1 as primary metric to sort by
df_metrics = df_metrics.round(3).sort_values("F1", axis=1, ascending=False)
Expand Down
52 changes: 43 additions & 9 deletions models/wrenformer/analyze_wrenformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@


# %%
import pandas as pd
from aviary.wren.utils import get_isopointal_proto_from_aflow
from pymatviz import spacegroup_hist, spacegroup_sunburst
from pymatviz.ptable import ptable_heatmap_plotly
from pymatviz.utils import save_fig

from matbench_discovery import ROOT
from matbench_discovery.data import df_wbm
from matbench_discovery import FIGS, ROOT
from matbench_discovery.data import DATA_FILES, df_wbm
from matbench_discovery.plots import df_to_svelte_table
from matbench_discovery.preds import df_each_pred, df_preds, each_true_col

__author__ = "Janosh Riebesell"
Expand All @@ -27,21 +30,52 @@
wyk_col = "wyckoff_spglib"
df_wbm[spg_col] = df_wbm[wyk_col].str.split("_").str[2].astype(int)
df_bad = df_wbm.loc[bad_ids]
title = f"{len(df_bad)} {model} preds with {max_each_true=}, {min_each_pred=}"
title = f"{len(df_bad)} {model} preds<br>with {max_each_true=}, {min_each_pred=}"


# %%
df_mp = pd.read_csv(DATA_FILES.mp_energies).set_index("material_id")
df_mp[spg_col] = df_mp[wyk_col].str.split("_").str[2].astype(int)
df_mp["isopointal_proto_from_aflow"] = df_mp[wyk_col].map(
get_isopointal_proto_from_aflow
)
df_mp.isopointal_proto_from_aflow.value_counts().head(12)


# %%
ax = spacegroup_hist(df_bad[spg_col])
fig = spacegroup_sunburst(df_bad[spg_col])
fig.layout.title = f"Spacegroup sunburst for {title}"
ax.set_title(f"Spacegroup hist for {title}", y=1.15)
out_dir = f"{ROOT}/tmp/figures"
save_fig(fig, f"{out_dir}/bad-{model}-spacegroup-sunburst.png", scale=3)
save_fig(ax, f"{out_dir}/bad-{model}-spacegroup-hist.png", dpi=300)
# save_fig(ax, f"{ROOT}/tmp/figures/spacegroup-hist-{model}-failures.png", dpi=300)


# %%
proto_col = "Isopointal Prototypes in Shaded Rectangle"
df_proto_counts = (
df_bad[wyk_col].map(get_isopointal_proto_from_aflow).value_counts().to_frame()
)

df_proto_counts["MP occurrences"] = 0
mp_proto_counts = df_mp.isopointal_proto_from_aflow.value_counts()
for proto in df_proto_counts.index:
df_proto_counts.loc[proto, "MP occurrences"] = mp_proto_counts.get(proto, 0)

df_proto_counts = df_proto_counts.reset_index(names=proto_col)
styler = df_proto_counts.head(10).style.background_gradient(cmap="viridis")

df_to_svelte_table(styler, f"{FIGS}/proto-counts-{model}-failures.svelte")


# %%
fig = spacegroup_sunburst(df_bad[spg_col], width=350, height=350)
fig.layout.title.update(text=f"Spacegroup sunburst for {title}", x=0.5, font_size=14)
fig.show()
# save_fig(fig, f"{ROOT}/tmp/figures/spacegroup-sunburst-{model}-failures.png", scale=3)
save_fig(fig, f"{FIGS}/spacegroup-sunburst-{model}-failures.svelte")


# %%
fig = ptable_heatmap_plotly(df_bad.formula)
fig.layout.title = f"Elements in {title}"
fig.layout.margin = dict(l=0, r=0, t=50, b=0)
save_fig(fig, f"{out_dir}/bad-{model}-elements.png", scale=3)
fig.show()
save_fig(fig, f"{ROOT}/tmp/figures/elements-{model.lower()}-failures.pdf")
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ dependencies = [
"matplotlib",
"pymatgen",
"numpy",
"pandas",
# output_formatting needed to for pandas Stylers
# see https://github.com/pandas-dev/pandas/blob/main/pyproject.toml#L78
"pandas[output_formatting]",
"scikit-learn",
"scipy",
"plotly",
Expand Down
Loading

0 comments on commit 01e88f0

Please # to comment.