diff --git a/tests/test_examples.py b/tests/test_examples.py
index ce6f84ddf..2b9980c67 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -14,7 +14,7 @@
 from numpy.typing import NDArray as Array
 from scipy.special import logsumexp
 
-from timemachine.constants import DEFAULT_KT, KCAL_TO_KJ
+from timemachine.constants import DEFAULT_FF, DEFAULT_KT, KCAL_TO_KJ
 from timemachine.datasets import fetch_freesolv
 from timemachine.fe.free_energy import assert_deep_eq
 from timemachine.fe.utils import get_mol_name
@@ -23,14 +23,20 @@
 EXAMPLES_DIR = Path(__file__).parent.parent / "examples"
 
 
-def hash_file(path: Path, chunk_size: int = 2048) -> str:
-    assert path.is_file(), f"{path!s} doesn't exist"
+def hash_directory_files(dir_path: Path, extensions: list[str], chunk_size: int = 2048) -> str:
+    assert dir_path.is_dir(), f"{dir_path!s} doesn't exist"
     m = hashlib.sha256()
-    with open(path, "rb") as ifs:
-        chunk = ifs.read(chunk_size)
-        while len(chunk) > 0:
-            m.update(chunk)
-            chunk = ifs.read(chunk_size)
+    assert len(extensions) > 0
+    for directory, _, files in dir_path.walk():
+        for f_name in sorted(files):
+            path = directory / f_name
+            if path.suffix not in extensions:
+                continue
+            with open(path, "rb") as ifs:
+                chunk = ifs.read(chunk_size)
+                while len(chunk) > 0:
+                    m.update(chunk)
+                    chunk = ifs.read(chunk_size)
     return m.hexdigest()
 
 
@@ -248,20 +254,13 @@ def test_run_rbfe_legs(
     mol_b,
     seed,
 ):
-    # Can generate hashes from CI artifacts
-    endstate_hashes = {
-        "vacuum": (
-            "81fe2a16aa7eb89e6a05e1d4e162d24b57e4c0c3effe07f6ff548ca32618d8e6",
-            "9c05f938850f0a0f85643e76f6201e8519f9d1c5ceefdd380b677f394f0b35f1",
-        ),
-        "solvent": (
-            "6d8b39f723727d556b47e0907e89f1afe9843cae0b2e6107a745d1c29f9a3c8d",
-            "d6ecd7973f6f2c9d70f9652364e6b926f5551dcca6be0514bda3f0cd3524ccd0",
-        ),
-        "complex": (
-            "b3c18188bd8fe2475152b5c1c8b9d81799037f8c33685db5834943a140cc2988",
-            "931c428e50c47c2a5442f2fa027d9a4acaf140339db34e2972d8ec6d4f00b40b",
-        ),
+    # To update the leg result hashes, refer to the hashes generated from CI runs.
+    # The CI jobs produce an artifact for the results stored at ARTIFACT_DIR_NAME
+    # which can be used to investigate the results that generated the hash.
+    leg_result_hashes = {
+        "vacuum": "81fe2a16aa7eb89e6a05e1d4e162d24b57e4c0c3effe07f6ff548ca32618d8e6",
+        "solvent": "6d8b39f723727d556b47e0907e89f1afe9843cae0b2e6107a745d1c29f9a3c8d",
+        "complex": "b3c18188bd8fe2475152b5c1c8b9d81799037f8c33685db5834943a140cc2988",
     }
     with resources.as_file(resources.files("timemachine.datasets.fep_benchmark.hif2a")) as hif2a_dir:
         config = dict(
@@ -274,8 +273,7 @@ def test_run_rbfe_legs(
             n_eq_steps=n_eq_steps,
             n_frames=n_frames,
             n_windows=n_windows,
-            # Use simple charges to avoid os-dependent charge differences
-            forcefield="smirnoff_1_1_0_sc.py",
+            forcefield=DEFAULT_FF,
             output_dir=f"{ARTIFACT_DIR_NAME}/rbfe_{mol_a}_{mol_b}_{leg}_{seed}",
         )
 
@@ -323,20 +321,17 @@ def verify_run(output_dir: Path):
         assert len(traj_data["coords"]) == n_frames
         assert len(traj_data["boxes"]) == n_frames
 
-    def verify_endstate_hashes(output_dir: Path):
+    def verify_leg_result_hashes(output_dir: Path):
         leg_dir = output_dir / leg
-        endstate_0_hash = hash_file(leg_dir / "lambda0_traj.npz")
-        endstate_1_hash = hash_file(leg_dir / "lambda1_traj.npz")
-        assert endstate_0_hash == endstate_hashes[leg][0] and endstate_1_hash == endstate_hashes[leg][1], (
-            f"{endstate_0_hash} != {endstate_hashes[leg][0]} and/or {endstate_1_hash} != {endstate_hashes[leg][1]}"
-        )
+        leg_results_hash = hash_directory_files(leg_dir, [".npz"])
+        assert leg_results_hash == leg_result_hashes[leg]
 
     config_a = config.copy()
     config_a["output_dir"] = config["output_dir"] + "_a"
     proc = run_example("run_rbfe_legs.py", get_cli_args(config_a))
     assert proc.returncode == 0
     verify_run(Path(config_a["output_dir"]))
-    verify_endstate_hashes(Path(config_a["output_dir"]))
+    verify_leg_result_hashes(Path(config_a["output_dir"]))
 
     config_b = config.copy()
     config_b["output_dir"] = config["output_dir"] + "_b"
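
Reviewer note (not part of the diff): below is a minimal sketch of how the new hash_directory_files helper might be driven locally to reproduce a leg hash from a downloaded CI artifact. The import path and the artifact directory are assumptions/placeholders, not something the diff defines. Extensions are passed with the leading dot because Path.suffix includes it (hence the [".npz"] call in the test), and the helper relies on Path.walk(), which is available from Python 3.12.

# Hypothetical local check, assuming the test module is importable and the CI
# artifact for a single leg has been unpacked to a local directory.
from pathlib import Path

from tests.test_examples import hash_directory_files  # helper added in this diff

# Placeholder: point this at the leg sub-directory of the downloaded CI artifact.
leg_dir = Path("downloaded_artifact/vacuum")

# Hash every .npz trajectory file under the leg directory, in sorted order.
# Path.suffix includes the leading dot, so the extension is given as ".npz".
local_hash = hash_directory_files(leg_dir, [".npz"])
print(local_hash)  # compare against leg_result_hashes["vacuum"] in the test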