diff --git a/FOX/io/hdf5_utils.py b/FOX/io/hdf5_utils.py index 16fdc483..39fd4e50 100644 --- a/FOX/io/hdf5_utils.py +++ b/FOX/io/hdf5_utils.py @@ -126,7 +126,7 @@ def create_hdf5(filename: PathType, armc: ARMC) -> None: pd_dict = { 'aux_error': pd.Series(np.nan, index=aux_error_idx, name='aux_error'), 'aux_error_mod': pd.Series(np.nan, index=idx, name='aux_error_mod'), - 'param': armc.param.param[0], + 'param': armc.param.param, 'phi': pd.Series(np.nan, index=np.arange(kappa), name='phi'), } pd_dict2 = { @@ -811,6 +811,9 @@ def dset_to_df(f: File, key: str) -> Union[pd.DataFrame, List[pd.DataFrame]]: else: data = f[key][:] + if key == "/param": + data = np.swapaxes(data, 1, 2) + # Return a DataFrame or list of DataFrames if data.ndim == 2: return pd.DataFrame(data, index=index, columns=columns) diff --git a/FOX/testing_utils.py b/FOX/testing_utils.py index 2c18f1d2..c8385e99 100644 --- a/FOX/testing_utils.py +++ b/FOX/testing_utils.py @@ -16,6 +16,7 @@ """ +import re import warnings from types import MappingProxyType from os import listdir @@ -256,8 +257,9 @@ def load_results(workdir, *, result_type=CP2KMM_Result, n=1): # noqa: E302 ret = [] tmp = [] i = -n + file_pattern = re.compile(r"md(\.([0-9]+))?") for jobname in sorted(listdir(workdir_path)): - if jobname == '__pycache__': + if file_pattern.fullmatch(jobname) is None: continue plams_dir = workdir_path / jobname diff --git a/tests/test_hdf5_utils.py b/tests/test_hdf5_utils.py index 435845b2..cb90db90 100644 --- a/tests/test_hdf5_utils.py +++ b/tests/test_hdf5_utils.py @@ -235,6 +235,6 @@ def test_from_hdf5(): assertion.eq(hdf5_dict['aux_error'], out['aux_error']['rdf.0'][0]) assertion.eq(hdf5_dict['phi'], out['aux_error_mod'].values[:, -1]) assertion.eq(hdf5_dict['phi'][0], out['phi'].loc[0, 0]) - np.testing.assert_allclose(hdf5_dict['param'], out['param'].values) + np.testing.assert_allclose(hdf5_dict['param'], out['param'][0].T) np.testing.assert_allclose(hdf5_dict['rdf.0'], out['rdf.0'][0].values) np.testing.assert_array_equal(PARAM_METADATA, out['param_metadata'])