Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

FIX: Fix generate baseline #203

Merged
merged 1 commit into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ test-coverage:

#% Generate baseline for test with args (see argparser in gen_baseline.py)
test-baseline:
cd smash/tests ; python3 gen_baseline.py
cd smash/tests ; python3 generate_baseline.py

#% Format Python files with ruff and Fortran files with fprettify
format:
Expand Down
39 changes: 25 additions & 14 deletions smash/tests/gen_baseline.py → smash/tests/generate_baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
import smash
from smash._constant import STRUCTURE

sys.path.insert(0, "")
# Change current directory to smash/smash
os.chdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir))


def parser():
parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -56,9 +60,9 @@ def adjust_module_names(module_names: list[str]) -> list[str]:

pattern = re.compile("|".join(rep.keys()))

ret = ["smash.tests." + pattern.sub(lambda m: rep[re.escape(m.group(0))], name) for name in module_names]
ret = [pattern.sub(lambda m: rep[re.escape(m.group(0))], name) for name in module_names]

ret.remove("smash.tests.test_define_global_vars")
ret.remove("tests.test_define_global_vars")

return ret

Expand Down Expand Up @@ -128,9 +132,9 @@ def compare_baseline(f: h5py.File, new_f: h5py.File):
df["TEST NAME" + (max_len_name - 8) * " "] = test_name
df["STATUS"] = status

df.to_csv("diff_baseline.csv", sep="|", index=False)
df.to_csv("tests/diff_baseline.csv", sep="|", index=False)

os.system('echo "$(git show --no-patch)\n\n$(cat diff_baseline.csv)" > diff_baseline.csv')
os.system('echo "$(git show --no-patch)\n\n$(cat tests/diff_baseline.csv)" > tests/diff_baseline.csv')


if __name__ == "__main__":
Expand All @@ -152,6 +156,9 @@ def compare_baseline(f: h5py.File, new_f: h5py.File):

model = smash.Model(setup, mesh)

# Do not need to read prcp and pet again
setup["read_prcp"] = False
setup["read_pet"] = False
model_structure = []

for structure in STRUCTURE:
Expand All @@ -160,19 +167,23 @@ def compare_baseline(f: h5py.File, new_f: h5py.File):
setup["hydrological_module"],
setup["routing_module"],
) = structure.split("-")
model_structure.append(smash.Model(setup, mesh))

# % Enable stdout
wmodel = smash.Model(setup, mesh)
wmodel.atmos_data.prcp = model.atmos_data.prcp
wmodel.atmos_data.pet = model.atmos_data.pet
if "ci" in wmodel.rr_parameters.keys:
wmodel.set_rr_parameters("ci", model.get_rr_parameters("ci"))
model_structure.append(wmodel)

# # % Enable stdout
sys.stdout = sys.__stdout__

module_names = sorted(glob.glob("**/test_*.py", recursive=True))
module_names = sorted(glob.glob("tests/**/test_*.py", recursive=True))

module_names = adjust_module_names(module_names)

if os.path.exists("new_baseline.hdf5"):
os.remove("new_baseline.hdf5")
if os.path.exists("tests/new_baseline.hdf5"):
os.remove("tests/new_baseline.hdf5")

with h5py.File("new_baseline.hdf5", "w") as f:
with h5py.File("tests/new_baseline.hdf5", "w") as f:
for mn in module_names:
print(mn, end=" ")
module = importlib.import_module(mn)
Expand Down Expand Up @@ -201,8 +212,8 @@ def compare_baseline(f: h5py.File, new_f: h5py.File):
print(".", end="", flush=True)
print("")

baseline = h5py.File("baseline.hdf5")
new_baseline = h5py.File("new_baseline.hdf5")
baseline = h5py.File("tests/baseline.hdf5")
new_baseline = h5py.File("tests/new_baseline.hdf5")

compare_baseline(baseline, new_baseline)

Expand Down