Skip to content

Commit

Permalink
[FIX] Contiguity on exogenous (#591)
Browse files Browse the repository at this point in the history
Co-authored-by: José Morales <jmoralz92@gmail.com>
  • Loading branch information
elephaint and jmoralez authored Jan 20, 2025
1 parent 9e067b4 commit deed23b
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 26 deletions.
52 changes: 40 additions & 12 deletions nbs/src/nixtla_client.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -798,22 +798,29 @@
" payload: dict[str, Any],\n",
" multithreaded_compress: bool,\n",
" ) -> dict[str, Any]:\n",
" def ensure_contiguous_if_array(x):\n",
" if not isinstance(x, np.ndarray):\n",
" return x\n",
" if np.issubdtype(x.dtype, np.floating):\n",
" x = np.nan_to_num(\n",
" np.ascontiguousarray(x, dtype=np.float32),\n",
" nan=np.nan,\n",
" posinf=np.finfo(np.float32).max,\n",
" neginf=np.finfo(np.float32).min,\n",
" copy=False,\n",
" )\n",
" else:\n",
" x = np.ascontiguousarray(x)\n",
" return x\n",
"\n",
" def ensure_contiguous_arrays(d: dict[str, Any]) -> None:\n",
" for k, v in d.items():\n",
" if isinstance(v, np.ndarray):\n",
" if np.issubdtype(v.dtype, np.floating):\n",
" v_cont = np.ascontiguousarray(v, dtype=np.float32)\n",
" d[k] = np.nan_to_num(\n",
" v_cont, \n",
" nan=np.nan, \n",
" posinf=np.finfo(np.float32).max, \n",
" neginf=np.finfo(np.float32).min,\n",
" copy=False,\n",
" )\n",
" else:\n",
" d[k] = np.ascontiguousarray(v)\n",
" d[k] = ensure_contiguous_if_array(v)\n",
" elif isinstance(v, list):\n",
" d[k] = [ensure_contiguous_if_array(x) for x in v] \n",
" elif isinstance(v, dict):\n",
" ensure_contiguous_arrays(v) \n",
" ensure_contiguous_arrays(v)\n",
"\n",
" ensure_contiguous_arrays(payload)\n",
" content = orjson.dumps(payload, option=orjson.OPT_SERIALIZE_NUMPY)\n",
Expand Down Expand Up @@ -3369,6 +3376,27 @@
" h=7,\n",
" add_history=True,\n",
" num_partitions=2,\n",
" )\n",
" df_freq[\"exog_1\"] = 1\n",
" test_num_partitions_same_results(\n",
" nixtla_client.detect_anomalies,\n",
" level=98,\n",
" df=df_freq,\n",
" num_partitions=2,\n",
" )\n",
" test_num_partitions_same_results(\n",
" nixtla_client.cross_validation,\n",
" h=7,\n",
" n_windows=2,\n",
" df=df_freq,\n",
" num_partitions=2,\n",
" )\n",
" test_num_partitions_same_results(\n",
" nixtla_client.forecast,\n",
" df=df_freq,\n",
" h=7,\n",
" add_history=True,\n",
" num_partitions=2,\n",
" )"
]
},
Expand Down
2 changes: 1 addition & 1 deletion nixtla/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
__version__ = "0.6.5"
__version__ = "0.6.6"
__all__ = ["NixtlaClient"]
from .nixtla_client import NixtlaClient
29 changes: 18 additions & 11 deletions nixtla/nixtla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,20 +723,27 @@ def _make_request(
payload: dict[str, Any],
multithreaded_compress: bool,
) -> dict[str, Any]:
def ensure_contiguous_if_array(x):
if not isinstance(x, np.ndarray):
return x
if np.issubdtype(x.dtype, np.floating):
x = np.nan_to_num(
np.ascontiguousarray(x, dtype=np.float32),
nan=np.nan,
posinf=np.finfo(np.float32).max,
neginf=np.finfo(np.float32).min,
copy=False,
)
else:
x = np.ascontiguousarray(x)
return x

def ensure_contiguous_arrays(d: dict[str, Any]) -> None:
for k, v in d.items():
if isinstance(v, np.ndarray):
if np.issubdtype(v.dtype, np.floating):
v_cont = np.ascontiguousarray(v, dtype=np.float32)
d[k] = np.nan_to_num(
v_cont,
nan=np.nan,
posinf=np.finfo(np.float32).max,
neginf=np.finfo(np.float32).min,
copy=False,
)
else:
d[k] = np.ascontiguousarray(v)
d[k] = ensure_contiguous_if_array(v)
elif isinstance(v, list):
d[k] = [ensure_contiguous_if_array(x) for x in v]
elif isinstance(v, dict):
ensure_contiguous_arrays(v)

Expand Down
4 changes: 2 additions & 2 deletions settings.ini
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ author = Nixtla
author_email = business@nixtla.io
copyright = Nixtla Inc.
branch = main
version = 0.6.5
version = 0.6.6
min_python = 3.9
audience = Developers
language = English
Expand All @@ -17,7 +17,7 @@ license = apache2
status = 4
requirements = annotated-types httpx[zstd] orjson pandas pydantic>=1.10 tenacity tqdm utilsforecast>=0.2.8
dev_requirements = black datasetsforecast fire hierarchicalforecast ipywidgets jupyterlab nbdev neuralforecast numpy<2 plotly polars pre-commit pyreadr python-dotenv pyyaml setuptools<70 statsforecast tabulate
distributed_requirements = fugue[dask,ray,spark]>=0.8.7 pandas<2.2 ray<2.6.3
distributed_requirements = fugue[dask,ray,spark]>=0.8.7 dask<=2024.12.1 pandas<2.2 ray<2.6.3
plotting_requirements = utilsforecast[plotting]
date_extra_requirements = holidays
nbs_path = nbs
Expand Down

0 comments on commit deed23b

Please # to comment.