From 8cebe12dcac337cbb174a76826eb9ec7b9496696 Mon Sep 17 00:00:00 2001 From: Francois Colleoni Date: Mon, 10 Jun 2024 10:33:26 +0200 Subject: [PATCH] FIX: Fix generate baseline --- Makefile | 2 +- .../{gen_baseline.py => generate_baseline.py} | 39 ++++++++++++------- 2 files changed, 26 insertions(+), 15 deletions(-) rename smash/tests/{gen_baseline.py => generate_baseline.py} (79%) diff --git a/Makefile b/Makefile index 76da3553..83beb661 100644 --- a/Makefile +++ b/Makefile @@ -41,7 +41,7 @@ test-coverage: #% Generate baseline for test with args (see argparser in gen_baseline.py) test-baseline: - cd smash/tests ; python3 gen_baseline.py + cd smash/tests ; python3 generate_baseline.py #% Format Python files with ruff and Fortran files with fprettify format: diff --git a/smash/tests/gen_baseline.py b/smash/tests/generate_baseline.py similarity index 79% rename from smash/tests/gen_baseline.py rename to smash/tests/generate_baseline.py index 7181c4d4..6aedfe91 100644 --- a/smash/tests/gen_baseline.py +++ b/smash/tests/generate_baseline.py @@ -15,6 +15,10 @@ import smash from smash._constant import STRUCTURE +sys.path.insert(0, "") +# Change current directory to smash/smash +os.chdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)) + def parser(): parser = argparse.ArgumentParser() @@ -56,9 +60,9 @@ def adjust_module_names(module_names: list[str]) -> list[str]: pattern = re.compile("|".join(rep.keys())) - ret = ["smash.tests." + pattern.sub(lambda m: rep[re.escape(m.group(0))], name) for name in module_names] + ret = [pattern.sub(lambda m: rep[re.escape(m.group(0))], name) for name in module_names] - ret.remove("smash.tests.test_define_global_vars") + ret.remove("tests.test_define_global_vars") return ret @@ -128,9 +132,9 @@ def compare_baseline(f: h5py.File, new_f: h5py.File): df["TEST NAME" + (max_len_name - 8) * " "] = test_name df["STATUS"] = status - df.to_csv("diff_baseline.csv", sep="|", index=False) + df.to_csv("tests/diff_baseline.csv", sep="|", index=False) - os.system('echo "$(git show --no-patch)\n\n$(cat diff_baseline.csv)" > diff_baseline.csv') + os.system('echo "$(git show --no-patch)\n\n$(cat tests/diff_baseline.csv)" > tests/diff_baseline.csv') if __name__ == "__main__": @@ -152,6 +156,9 @@ def compare_baseline(f: h5py.File, new_f: h5py.File): model = smash.Model(setup, mesh) + # Do not need to read prcp and pet again + setup["read_prcp"] = False + setup["read_pet"] = False model_structure = [] for structure in STRUCTURE: @@ -160,19 +167,23 @@ def compare_baseline(f: h5py.File, new_f: h5py.File): setup["hydrological_module"], setup["routing_module"], ) = structure.split("-") - model_structure.append(smash.Model(setup, mesh)) - - # % Enable stdout + wmodel = smash.Model(setup, mesh) + wmodel.atmos_data.prcp = model.atmos_data.prcp + wmodel.atmos_data.pet = model.atmos_data.pet + if "ci" in wmodel.rr_parameters.keys: + wmodel.set_rr_parameters("ci", model.get_rr_parameters("ci")) + model_structure.append(wmodel) + + # # % Enable stdout sys.stdout = sys.__stdout__ - - module_names = sorted(glob.glob("**/test_*.py", recursive=True)) + module_names = sorted(glob.glob("tests/**/test_*.py", recursive=True)) module_names = adjust_module_names(module_names) - if os.path.exists("new_baseline.hdf5"): - os.remove("new_baseline.hdf5") + if os.path.exists("tests/new_baseline.hdf5"): + os.remove("tests/new_baseline.hdf5") - with h5py.File("new_baseline.hdf5", "w") as f: + with h5py.File("tests/new_baseline.hdf5", "w") as f: for mn in module_names: print(mn, end=" ") module = importlib.import_module(mn) @@ -201,8 +212,8 @@ def compare_baseline(f: h5py.File, new_f: h5py.File): print(".", end="", flush=True) print("") - baseline = h5py.File("baseline.hdf5") - new_baseline = h5py.File("new_baseline.hdf5") + baseline = h5py.File("tests/baseline.hdf5") + new_baseline = h5py.File("tests/new_baseline.hdf5") compare_baseline(baseline, new_baseline)