Skip to content

Commit

Permalink
for #185: fix FileNotFoundError from Model.<key>.label when pip insta…
Browse files Browse the repository at this point in the history
…ll . (i.e. non-editable install)

- add 'build' directory to .gitignore
- fix pyproject.toml to include model YAML files in package data
- remove more matplotlib remnants with plotly in data/wbm/compare_cse_vs_ce_mp_2020_corrections.py
  • Loading branch information
janosh committed Dec 25, 2024
1 parent fa0d56e commit a937a1a
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 56 deletions.
3 changes: 1 addition & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# package install
*.egg-info
dist
build

# cache
__pycache__
Expand All @@ -14,7 +15,6 @@ __pycache__
data/**/raw
data/**/tsne
!data/mp/2023-02-07-mp-elemental-reference-entries.json.gz
models/**/checkpoints
data/**/*.extxyz*
data/**/*.json*
data/**/*.zip*
Expand All @@ -34,7 +34,6 @@ models/**/*.tgz
site/src/routes/api/*.md

# temporary ignore rules
models/cgcnn/2023-02-05-ens=10-perturb=5
data/mp/mptrj-gga-ggapu/*

# large files
Expand Down
93 changes: 49 additions & 44 deletions data/wbm/compare_cse_vs_ce_mp_2020_corrections.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
"""NOTE MaterialsProject2020Compatibility takes structural information into account when
correcting energies (for oxides and sulfides only). Always use
ComputedStructureEntry, not ComputedEntry when applying energy corrections.
"""MaterialsProject2020Compatibility takes structural information into account when
correcting energies (for oxides and sulfides only). Always pass
ComputedStructureEntry, not ComputedEntry when calling process_entries.
"""

# %%
import gzip
import json

import matplotlib.pyplot as plt
import pandas as pd
import pymatviz as pmv
from pymatgen.entries.compatibility import (
Expand All @@ -28,11 +27,17 @@

cses = [
ComputedStructureEntry.from_dict(dct)
for dct in tqdm(df_cse[Key.computed_structure_entry])
for dct in tqdm(
df_cse[Key.computed_structure_entry],
desc="Loading ComputedStructureEntries",
)
]

ces = [
ComputedEntry.from_dict(dct) for dct in tqdm(df_cse[Key.computed_structure_entry])
ComputedEntry.from_dict(dct)
for dct in tqdm(
df_cse[Key.computed_structure_entry], desc="Loading ComputedEntries"
)
]


Expand All @@ -43,10 +48,14 @@
assert len(processed) == len(df_cse)

df_wbm["e_form_per_atom_mp2020_from_ce"] = [
get_e_form_per_atom(entry) for entry in tqdm(ces)
get_e_form_per_atom(entry)
for entry in tqdm(ces, desc="Calculating formation energies from ComputedEntries")
]
df_wbm["e_form_per_atom_mp2020_from_cse"] = [
get_e_form_per_atom(entry) for entry in tqdm(cses)
get_e_form_per_atom(entry)
for entry in tqdm(
cses, desc="Calculating formation energies from ComputedStructureEntries"
)
]

df_wbm["mp2020_cse_correction_per_atom"] = [
Expand Down Expand Up @@ -79,50 +88,46 @@
df_wbm["anion"][df_wbm[Key.chem_sys].astype(str).str.contains("'O'")] = "oxide"
df_wbm["anion"][df_wbm[Key.chem_sys].astype(str).str.contains("'S'")] = "sulfide"

assert dict(df_wbm.anion.value_counts()) == {"oxide": 26984, "sulfide": 10606}
anion_counts = dict(df_wbm.anion.value_counts())
assert anion_counts == {"oxide": 26_984, "sulfide": 10_596}, f"{anion_counts=}"

df_ce_ne_cse = df_wbm.query(
"abs(e_form_per_atom_mp2020_from_cse - e_form_per_atom_mp2020_from_ce) > 1e-4"
)


# %%
ax = plt.gca()
for key, df_anion in df_ce_ne_cse.groupby("anion"):
ax = df_anion.plot.scatter(
ax=ax,
x="mp2020_cse_correction_per_atom",
y="mp2020_ce_correction_per_atom",
label=f"{key} ({len(df_anion):,})",
color=dict(oxide="orange", sulfide="teal").get(key, "blue"),
title=f"CSE vs CE corrections for ({len(df_ce_ne_cse):,} / {len(df_wbm):,} = "
f"{len(df_ce_ne_cse) / len(df_wbm):.1%})\n outliers of largest difference",
for x_col, y_col, title in (
("mp2020_cse_correction_per_atom", "mp2020_ce_correction_per_atom", "correction"),
("e_form_per_atom_mp2020_from_cse", "e_form_per_atom_mp2020_from_ce", "e-form"),
):
fig = df_ce_ne_cse.plot.scatter(
x=x_col,
y=y_col,
color="anion",
color_discrete_map={"oxide": "orange", "sulfide": "teal"},
hover_data=[Key.formula],
backend="plotly",
)

ax.axline((0, 0), slope=1, color="gray", linestyle="dashed", zorder=-1)

pmv.save_fig(ax, f"{ROOT}/tmp/{today}-ce-vs-cse-corrections-outliers.pdf")


# %%
ax = plt.gca()
for key, df_anion in df_ce_ne_cse.groupby("anion"):
ax = df_anion.plot.scatter(
ax=ax,
x="e_form_per_atom_mp2020_from_cse",
y="e_form_per_atom_mp2020_from_ce",
label=f"{key} ({len(df_anion):,})",
color=dict(oxide="orange", sulfide="teal").get(key, "blue"),
title=f"Outliers in formation energy from CSE vs CE ({len(df_ce_ne_cse):,}"
f" / {len(df_wbm):,} = {len(df_ce_ne_cse) / len(df_wbm):.1%})",
)

ax.axline((0, 0), slope=1, color="gray", linestyle="dashed", zorder=-1)

# insight: all materials for which ComputedEntry and ComputedStructureEntry give
# different formation energies are oxides or sulfides for which MP 2020 compat takes
# into account structural information to make more accurate corrections.
pmv.save_fig(ax, f"{ROOT}/tmp/{today}-ce-vs-cse-e-form-outliers.pdf")
title = f"CSE vs CE {title}<br>({len(df_ce_ne_cse):,} / {len(df_wbm):,} "
title += f"= {len(df_ce_ne_cse) / len(df_wbm):.1%})"
fig.layout.title.update(text=title, x=0.5, font=dict(size=16))
fig.layout.margin.t = 40
fig.layout.legend.update(x=0, title=None)

pmv.powerups.add_identity_line(fig)

# Update legend labels with counts
for trace in fig.data:
anion = trace.name
count = len(df_ce_ne_cse[df_ce_ne_cse.anion == anion])
trace.name = f"{anion} ({count:,})"

# insight: all materials for which ComputedEntry and ComputedStructureEntry give
# different formation energies are oxides or sulfides for which MP 2020 compat takes
# into account structural information to make more accurate corrections.
pmv.save_fig(fig, f"{ROOT}/tmp/{today}-ce-vs-cse-{title}-outliers.pdf")
fig.show()


# %% below code resulted in
Expand Down
2 changes: 1 addition & 1 deletion matbench_discovery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

PKG_DIR = os.path.dirname(__file__)
# repo root directory if editable install, else the pkg directory
ROOT = os.path.dirname(PKG_DIR) if pkg_is_editable else PKG_DIR
ROOT = os.path.dirname(PKG_DIR)
DATA_DIR = f"{ROOT}/data" # directory to store raw data
SITE_FIGS = f"{ROOT}/site/src/figs" # directory for interactive figures
# directory to write model analysis for website
Expand Down
10 changes: 3 additions & 7 deletions matbench_discovery/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,11 +222,6 @@ def __new__(
obj._base_dir = base_dir # noqa: SLF001
return obj

@property
def member_map(cls: type[T]) -> dict[str, "Files"]: # type: ignore[misc]
"""Map of member names to member objects."""
return cls._member_map_ # type: ignore[return-value]

@property
def base_dir(cls) -> str:
"""Base directory of the file."""
Expand Down Expand Up @@ -504,7 +499,8 @@ def phonons_path(self) -> str | None:
return f"{ROOT}/{rel_path}"


px.defaults.labels |= {k: v.label for k, v in Model.member_map.items()}
# render model keys as labels in plotly axes and legends
px.defaults.labels |= {k.name: k.label for k in Model}


def load_df_wbm_with_preds(
Expand Down Expand Up @@ -545,7 +541,7 @@ def load_df_wbm_with_preds(
valid_models = {model.name for model in Model}
if models == ():
models = tuple(valid_models)
inv_label_map = {v.label: k for k, v in Model.member_map.items()}
inv_label_map = {key.label: key.name for key in Model}
# map pretty model names back to Model enum keys
models = {inv_label_map.get(model, model) for model in models}
if unknown_models := ", ".join(models - valid_models):
Expand Down
2 changes: 1 addition & 1 deletion models/chgnet/join_chgnet_preds.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@


# %%
ax = pmv.density_scatter_plotly(
pmv.density_scatter_plotly(
df=df_preds,
x=MbdKey.e_form_dft,
y=e_form_chgnet_col,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ include = ["matbench_discovery*"]
exclude = ["tests", "tests.*"]

[tool.setuptools.package-data]
matbench_discovery = ["**/*.yml"]
matbench_discovery = ["**/*.yml", "../models/**/*.yml"]

[build-system]
requires = ["setuptools>=70"]
Expand Down

0 comments on commit a937a1a

Please # to comment.