PR Feedback: Hash directory results
* Enforces that dG predictions match (see the sketch below)
* Switches to the default forcefield to match standard behavior
badisa committed Feb 21, 2025
1 parent 23b23a3 · commit 673dd18
Showing 1 changed file with 26 additions and 31 deletions.
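A note on the first bullet: the assertion that dG predictions agree between the two runs is not visible in the hunks below (it presumably sits in a collapsed portion of the test, perhaps via the assert_deep_eq import shown in the first hunk). As a rough, hypothetical sketch of such a check, with made-up file names and npz keys:

    import numpy as np

    def assert_dg_predictions_match(results_a_npz: str, results_b_npz: str) -> None:
        # Hypothetical: load the dG estimates written by two runs of the same leg
        # and require that they match exactly.
        run_a = np.load(results_a_npz)
        run_b = np.load(results_b_npz)
        np.testing.assert_allclose(run_a["pred_dg"], run_b["pred_dg"], rtol=0.0, atol=0.0)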
tests/test_examples.py (57 changes: 26 additions & 31 deletions)
@@ -14,7 +14,7 @@
 from numpy.typing import NDArray as Array
 from scipy.special import logsumexp

-from timemachine.constants import DEFAULT_KT, KCAL_TO_KJ
+from timemachine.constants import DEFAULT_FF, DEFAULT_KT, KCAL_TO_KJ
 from timemachine.datasets import fetch_freesolv
 from timemachine.fe.free_energy import assert_deep_eq
 from timemachine.fe.utils import get_mol_name
@@ -23,14 +23,20 @@
EXAMPLES_DIR = Path(__file__).parent.parent / "examples"


-def hash_file(path: Path, chunk_size: int = 2048) -> str:
-    assert path.is_file(), f"{path!s} doesn't exist"
+def hash_directory_files(dir_path: Path, extensions: list[str], chunk_size: int = 2048) -> str:
+    assert dir_path.is_dir(), f"{dir_path!s} doesn't exist"
     m = hashlib.sha256()
-    with open(path, "rb") as ifs:
-        chunk = ifs.read(chunk_size)
-        while len(chunk) > 0:
-            m.update(chunk)
-            chunk = ifs.read(chunk_size)
+    assert len(extensions) > 0
+    for directory, _, files in dir_path.walk():
+        for f_name in sorted(files):
+            path = directory / f_name
+            if path.suffix.lstrip(".") not in extensions:  # Path.suffix includes the leading ".", so accept extensions given without it
+                continue
+            with open(path, "rb") as ifs:
+                chunk = ifs.read(chunk_size)
+                while len(chunk) > 0:
+                    m.update(chunk)
+                    chunk = ifs.read(chunk_size)
     return m.hexdigest()


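For orientation, a usage sketch of the helper above (the results path is hypothetical): hash_directory_files walks the directory tree, hashes matching files in sorted name order within each directory, and folds them into one SHA-256 digest, so any byte-level change to any matching file changes the result. Two caveats: pathlib.Path.walk requires Python 3.12+, and while file names are sorted, the order in which subdirectories are visited is not, which is harmless for the flat per-leg output directories hashed here.

    from pathlib import Path

    # Hypothetical output directory written by one leg of the RBFE example
    leg_dir = Path("artifacts/rbfe_mol_a_mol_b_vacuum_2024_a/vacuum")
    digest = hash_directory_files(leg_dir, ["npz"])  # covers every *.npz under leg_dir
    print(digest)  # compare against the expected hash recorded in the test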
@@ -248,20 +254,13 @@ def test_run_rbfe_legs(
     mol_b,
     seed,
 ):
-    # Can generate hashes from CI artifacts
-    endstate_hashes = {
-        "vacuum": (
-            "81fe2a16aa7eb89e6a05e1d4e162d24b57e4c0c3effe07f6ff548ca32618d8e6",
-            "9c05f938850f0a0f85643e76f6201e8519f9d1c5ceefdd380b677f394f0b35f1",
-        ),
-        "solvent": (
-            "6d8b39f723727d556b47e0907e89f1afe9843cae0b2e6107a745d1c29f9a3c8d",
-            "d6ecd7973f6f2c9d70f9652364e6b926f5551dcca6be0514bda3f0cd3524ccd0",
-        ),
-        "complex": (
-            "b3c18188bd8fe2475152b5c1c8b9d81799037f8c33685db5834943a140cc2988",
-            "931c428e50c47c2a5442f2fa027d9a4acaf140339db34e2972d8ec6d4f00b40b",
-        ),
+    # To update the leg result hashes, refer to the hashes generated from CI runs.
+    # The CI jobs produce an artifact for the results stored at ARTIFACT_DIR_NAME
+    # which can be used to investigate the results that generated the hash.
+    leg_result_hashes = {
+        "vacuum": "81fe2a16aa7eb89e6a05e1d4e162d24b57e4c0c3effe07f6ff548ca32618d8e6",
+        "solvent": "6d8b39f723727d556b47e0907e89f1afe9843cae0b2e6107a745d1c29f9a3c8d",
+        "complex": "b3c18188bd8fe2475152b5c1c8b9d81799037f8c33685db5834943a140cc2988",
     }
     with resources.as_file(resources.files("timemachine.datasets.fep_benchmark.hif2a")) as hif2a_dir:
         config = dict(
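The comment above says the expected values come from CI. A hypothetical way to refresh them locally after downloading and unpacking the CI artifact (the directory names are illustrative and follow the output_dir pattern set further down in this test):

    from pathlib import Path

    artifact_root = Path("downloaded_ci_artifact")  # illustrative path to the unpacked artifact
    for leg_name in ("vacuum", "solvent", "complex"):
        # Each run writes its results under <output_dir>/<leg>; "_a" is the first of the two runs.
        leg_dir = next(artifact_root.glob(f"rbfe_*_{leg_name}_*_a")) / leg_name
        print(leg_name, hash_directory_files(leg_dir, ["npz"]))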
@@ -274,8 +273,7 @@
             n_eq_steps=n_eq_steps,
             n_frames=n_frames,
             n_windows=n_windows,
-            # Use simple charges to avoid os-dependent charge differences
-            forcefield="smirnoff_1_1_0_sc.py",
+            forcefield=DEFAULT_FF,
             output_dir=f"{ARTIFACT_DIR_NAME}/rbfe_{mol_a}_{mol_b}_{leg}_{seed}",
)

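On the second bullet: the removed comment explains that smirnoff_1_1_0_sc.py was previously pinned to avoid OS-dependent charge differences; the test now uses the package default so it exercises standard behavior. The diff does not show which forcefield file DEFAULT_FF names; it can be checked directly:

    from timemachine.constants import DEFAULT_FF

    # Prints the file name of the default built-in forcefield now used by the test
    print(DEFAULT_FF)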
@@ -323,20 +321,17 @@ def verify_run(output_dir: Path):
             assert len(traj_data["coords"]) == n_frames
             assert len(traj_data["boxes"]) == n_frames

-        def verify_endstate_hashes(output_dir: Path):
+        def verify_leg_result_hashes(output_dir: Path):
             leg_dir = output_dir / leg
-            endstate_0_hash = hash_file(leg_dir / "lambda0_traj.npz")
-            endstate_1_hash = hash_file(leg_dir / "lambda1_traj.npz")
-            assert endstate_0_hash == endstate_hashes[leg][0] and endstate_1_hash == endstate_hashes[leg][1], (
-                f"{endstate_0_hash} != {endstate_hashes[leg][0]} and/or {endstate_1_hash} != {endstate_hashes[leg][1]}"
-            )
+            leg_results_hash = hash_directory_files(leg_dir, ["npz"])
+            assert leg_results_hash == leg_result_hashes[leg]

         config_a = config.copy()
         config_a["output_dir"] = config["output_dir"] + "_a"
         proc = run_example("run_rbfe_legs.py", get_cli_args(config_a))
         assert proc.returncode == 0
         verify_run(Path(config_a["output_dir"]))
-        verify_endstate_hashes(Path(config_a["output_dir"]))
+        verify_leg_result_hashes(Path(config_a["output_dir"]))

         config_b = config.copy()
         config_b["output_dir"] = config["output_dir"] + "_b"