Skip to content

Commit

Permalink
ev: Batch data for faster evaluation. (#3051)
Browse files Browse the repository at this point in the history
  • Loading branch information
lostella authored Nov 13, 2023
1 parent 37b2419 commit 56eec11
Showing 1 changed file with 59 additions and 34 deletions.
93 changes: 59 additions & 34 deletions src/gluonts/model/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import logging
from collections import ChainMap
from typing import Iterable, Optional, Union
from typing import Iterable, List, Optional, Union
from dataclasses import dataclass
from toolz import first, valmap

Expand All @@ -23,10 +23,10 @@

from gluonts.dataset import DataEntry
from gluonts.dataset.split import TestData
from gluonts.time_feature.seasonality import get_seasonality
from gluonts.model import Forecast, Predictor
from gluonts.ev.ts_stats import seasonal_error
from gluonts.itertools import prod
from gluonts.itertools import batcher, prod
from gluonts.model import Forecast, Predictor
from gluonts.time_feature.seasonality import get_seasonality

logger = logging.getLogger(__name__)

Expand All @@ -39,53 +39,60 @@ class BatchForecast:
``gluonts.ev``.
"""

forecast: Forecast
forecasts: List[Forecast]
allow_nan: bool = False

def __getitem__(self, name):
value = self.forecast[name]
if np.isnan(value).any():
values = [forecast[name].T for forecast in self.forecasts]
res = np.stack(values, axis=0)

if np.isnan(res).any():
if not self.allow_nan:
raise ValueError("Forecast contains NaN values")

logger.warning(
"Forecast contains NaN values. Metrics may be incorrect."
)

return np.expand_dims(value.T, axis=0)
return res


def _get_data_batch(
input_: DataEntry,
label: DataEntry,
forecast: Forecast,
input_batch: List[DataEntry],
label_batch: List[DataEntry],
forecast_batch: List[Forecast],
seasonality: Optional[int] = None,
mask_invalid_label: bool = True,
allow_nan_forecast: bool = False,
) -> ChainMap:
forecast_dict = BatchForecast(forecast, allow_nan=allow_nan_forecast)

freq = forecast.start_date.freqstr
if seasonality is None:
seasonality = get_seasonality(freq=freq)

label_target = label["target"]
input_target = input_["target"]
label_target = np.stack([label["target"] for label in label_batch], axis=0)
if mask_invalid_label:
label_target = np.ma.masked_invalid(label_target)
input_target = np.ma.masked_invalid(input_target)

other_data = {
"label": np.expand_dims(label_target, axis=0),
"seasonal_error": np.expand_dims(
seasonal_error(
input_target, seasonality=seasonality, time_axis=-1
),
axis=0,
),
"label": label_target,
}

return ChainMap(other_data, forecast_dict) # type: ignore
seasonal_error_values = []
for input_ in input_batch:
seasonality_entry = seasonality
if seasonality_entry is None:
seasonality_entry = get_seasonality(input_["start"].freqstr)
input_target = input_["target"]
if mask_invalid_label:
input_target = np.ma.masked_invalid(input_target)
seasonal_error_values.append(
seasonal_error(
input_target,
seasonality=seasonality_entry,
time_axis=-1,
)
)
other_data["seasonal_error"] = np.array(seasonal_error_values)

return ChainMap(
other_data, BatchForecast(forecast_batch, allow_nan=allow_nan_forecast) # type: ignore
)


def evaluate_forecasts_raw(
Expand All @@ -94,6 +101,7 @@ def evaluate_forecasts_raw(
test_data: TestData,
metrics,
axis: Optional[Union[int, tuple]] = None,
batch_size: int = 100,
mask_invalid_label: bool = True,
allow_nan_forecast: bool = False,
seasonality: Optional[int] = None
Expand Down Expand Up @@ -130,16 +138,26 @@ def evaluate_forecasts_raw(

index_data = []

for input_, label, forecast in tqdm(
zip(test_data.input, test_data.label, forecasts)
input_batches = batcher(test_data.input, batch_size=batch_size)
label_batches = batcher(test_data.label, batch_size=batch_size)
forecast_batches = batcher(forecasts, batch_size=batch_size)

pbar = tqdm()
for input_batch, label_batch, forecast_batch in zip(
input_batches, label_batches, forecast_batches
):
if 0 not in axis:
index_data.append((forecast.item_id, forecast.start_date))
index_data.extend(
[
(forecast.item_id, forecast.start_date)
for forecast in forecast_batch
]
)

data_batch = _get_data_batch(
input_,
label,
forecast,
input_batch,
label_batch,
forecast_batch,
seasonality=seasonality,
mask_invalid_label=mask_invalid_label,
allow_nan_forecast=allow_nan_forecast,
Expand All @@ -148,6 +166,9 @@ def evaluate_forecasts_raw(
for evaluator in evaluators.values():
evaluator.update(data_batch)

pbar.update(len(forecast_batch))
pbar.close()

metrics_values = {
metric_name: evaluator.get()
for metric_name, evaluator in evaluators.items()
Expand All @@ -165,6 +186,7 @@ def evaluate_forecasts(
test_data: TestData,
metrics,
axis: Optional[Union[int, tuple]] = None,
batch_size: int = 100,
mask_invalid_label: bool = True,
allow_nan_forecast: bool = False,
seasonality: Optional[int] = None
Expand All @@ -188,6 +210,7 @@ def evaluate_forecasts(
test_data=test_data,
metrics=metrics,
axis=axis,
batch_size=batch_size,
mask_invalid_label=mask_invalid_label,
allow_nan_forecast=allow_nan_forecast,
seasonality=seasonality,
Expand Down Expand Up @@ -217,6 +240,7 @@ def evaluate_model(
test_data: TestData,
metrics,
axis: Optional[Union[int, tuple]] = None,
batch_size: int = 100,
mask_invalid_label: bool = True,
allow_nan_forecast: bool = False,
seasonality: Optional[int] = None
Expand All @@ -242,6 +266,7 @@ def evaluate_model(
test_data=test_data,
metrics=metrics,
axis=axis,
batch_size=batch_size,
mask_invalid_label=mask_invalid_label,
allow_nan_forecast=allow_nan_forecast,
seasonality=seasonality,
Expand Down

0 comments on commit 56eec11

Please # to comment.