Skip to content

Commit

Permalink
refactor!: more robust plot_panel
Browse files Browse the repository at this point in the history
Changed argument names

refactor: extract another function

refactor: complete plot_panel refactor

refactor: rename plotting utils module

refactor: add value boundary check

refactor: add boundary checks
  • Loading branch information
baggiponte committed Jun 9, 2024
1 parent 4ddab23 commit 0848504
Show file tree
Hide file tree
Showing 2 changed files with 214 additions and 45 deletions.
161 changes: 161 additions & 0 deletions functime/_utils_plotting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import plotly.graph_objects as go
import polars as pl
from plotly.subplots import make_subplots

if TYPE_CHECKING:
from typing import Optional, Tuple

COLOR_PALETTE = {"actual": "#B7B7B7", "forecast": "#1b57f1", "backtest": "#A76EF4"}


def make_figure(
    *,
    n_rows: int,
    n_cols: int,
    entities_chosen: pl.Series,
    default_title: str,
    **kwargs,
) -> go.Figure:
    """Create an empty subplot grid with one titled cell per chosen entity.

    Parameters
    ----------
    n_rows : int
        Number of rows in the subplot grid.
    n_cols : int
        Number of columns in the subplot grid.
    entities_chosen : pl.Series
        Entity identifiers; each one becomes a subplot title, in order.
    default_title : str
        Figure title used when `title` is not supplied in `kwargs`.
    **kwargs
        Additional keyword arguments forwarded to `plotly.graph_objects.Layout`.
        When absent, `height` defaults to `n_rows * 200`, `template` to
        `"plotly_white"` and `title` to `default_title`.

    Returns
    -------
    figure : plotly.graph_objects.Figure
        Figure carrying the layout and the `n_rows` x `n_cols` subplot grid.
    """
    height = kwargs.pop("height", n_rows * 200)
    template = kwargs.pop("template", "plotly_white")
    title = kwargs.pop("title", default_title)

    layout = go.Layout(
        template=template,
        height=height,
        title=title,
        **kwargs,
    )

    # Plotly documents `subplot_titles` as a list of strings. Hand it exactly
    # that instead of a polars Series, whose truth value is ambiguous and
    # whose elements may not be strings.
    subplot_titles = [str(entity) for entity in entities_chosen]

    return make_subplots(
        figure=go.Figure(layout=layout),
        rows=n_rows,
        cols=n_cols,
        subplot_titles=subplot_titles,
    )


def add_traces(
    *,
    figure: go.Figure,
    y: pl.DataFrame,
    name: str,
    color: str,
    n_cols: int,
) -> None:
    """Add one line trace per entity in `y` to the subplots of `figure`.

    The first three columns of `y` are read as entity, time and target
    columns. The i-th entity group is drawn in the i-th cell of the grid,
    filling rows left to right.

    Parameters
    ----------
    figure : plotly.graph_objects.Figure
        Figure with an existing subplot grid; modified in place.
    y : pl.DataFrame
        Panel DataFrame whose first three columns are entity, time, target.
    name : str
        Trace name, also used as the legend group.
    color : str
        Key into `COLOR_PALETTE` selecting the line color.
    n_cols : int
        Number of columns in the subplot grid.

    Raises
    ------
    KeyError
        If `color` is not a key of `COLOR_PALETTE` and `y` has at least one group.
    """
    entity_col, time_col, target_col = y.columns[:3]

    # `group_by` replaces the deprecated `groupby`. `maintain_order=True` keeps
    # groups in order of first appearance so subplot placement is deterministic.
    # NOTE(review): subplot titles are set elsewhere; confirm that the order of
    # entities in `y` matches the order of those titles.
    groups = y.group_by([entity_col], maintain_order=True)

    for i, (_, ts) in enumerate(groups):
        row, col = get_subplot_grid_position(element=i, num_cols=n_cols)

        figure.add_trace(
            go.Scatter(
                x=ts.get_column(time_col),
                y=ts.get_column(target_col),
                name=name,
                legendgroup=name,
                line=dict(color=COLOR_PALETTE[color]),
                showlegend=False,
            ),
            row=row,
            col=col,
        )


def get_chosen_entities(
    *,
    y: pl.LazyFrame,
    n_series,
    seed: Optional[int],
) -> pl.Series:
    """Pick `n_series` entity identifiers from the first column of `y`.

    Parameters
    ----------
    y : pl.LazyFrame
        Panel LazyFrame whose first column holds the entity identifiers.
    n_series
        Number of entities to pick.
    seed : Optional[int]
        Random seed. When None, the first `n_series` unique entities (in
        order of appearance) are returned; otherwise `n_series` entities are
        sampled using this seed.

    Returns
    -------
    entities_chosen : pl.Series
        The selected entity identifiers.
    """
    id_col = y.columns[0]
    unique_ids = pl.col(id_col).unique(maintain_order=True)
    all_entities = y.select(unique_ids).collect().to_series()

    # No seed: deterministic choice of the leading entities.
    if seed is None:
        return all_entities.slice(0, n_series)

    return all_entities.sample(n_series, seed=seed)


def get_num_points(
    *,
    entities_counts: pl.LazyFrame,
    num_points: Optional[int],
) -> int:
    """Resolve how many most-recent points to plot for every series.

    Parameters
    ----------
    entities_counts : pl.LazyFrame
        LazyFrame whose first column is a struct with a `"count"` field
        giving the number of observations per entity.
    num_points : Optional[int]
        Requested number of points. When None, defaults to 64 or the length
        of the shortest series, whichever is smaller.

    Returns
    -------
    num_points : int
        Number of points to plot per series.

    Raises
    ------
    ValueError
        If `num_points` is not positive, or exceeds the length of the
        shortest series.
    """
    # Validate before touching the data so bad input fails without a collect.
    if num_points is not None and num_points <= 0:
        raise ValueError("Number of points must be a positive integer")

    entity_col = entities_counts.columns[0]

    min_points = (
        entities_counts.select(pl.col(entity_col).struct.field("count").min())
        .collect()
        .item()
    )

    if num_points is None:
        return min(min_points, 64)

    if num_points > min_points:
        # The request is too large, not too small: say so in the message.
        raise ValueError(
            f"Number of points ({num_points}) is greater than minimum number of points ({min_points})"
        )
    return num_points


def get_num_series(
    *,
    entities_counts: pl.LazyFrame,
    num_series: Optional[int],
) -> int:
    """Resolve how many series (entities) to plot.

    Parameters
    ----------
    entities_counts : pl.LazyFrame
        LazyFrame whose first column is a struct that has a field named
        after that column, holding the entity identifiers.
    num_series : Optional[int]
        Requested number of series. When None, defaults to 10 or the number
        of distinct entities, whichever is smaller.

    Returns
    -------
    num_series : int
        Number of series to plot.

    Raises
    ------
    ValueError
        If `num_series` is not positive, or exceeds the number of distinct
        entities.
    """
    # Same guard (and message) as `get_num_rows`, applied up front so invalid
    # input is rejected before the frame is collected.
    if num_series is not None and num_series <= 0:
        raise ValueError("Number of series must be a positive integer")

    entity_col = entities_counts.columns[0]

    num_entities = (
        entities_counts.select(pl.col(entity_col).struct.field(entity_col).n_unique())
        .collect()
        .item()
    )

    if num_series is None:
        return min(num_entities, 10)

    if num_series > num_entities:
        raise ValueError(
            f"Number of entities ({num_entities}) is less than number of series ({num_series})"
        )
    return num_series


def get_num_rows(
    *,
    num_series: int,
    num_cols: int,
) -> int:
    """Return the number of grid rows needed to fit `num_series` subplots.

    Parameters
    ----------
    num_series : int
        Number of subplots to place.
    num_cols : int
        Number of subplots per row.

    Returns
    -------
    num_rows : int
        `num_series / num_cols`, rounded up.

    Raises
    ------
    ValueError
        If `num_cols` or `num_series` is smaller than 1.
    """
    if num_cols < 1:
        raise ValueError("Number of columns must be a positive integer")

    if num_series < 1:
        raise ValueError("Number of series must be a positive integer")

    full_rows, leftover = divmod(num_series, num_cols)
    # A partially filled last row still occupies a whole row.
    return full_rows + 1 if leftover != 0 else full_rows


def get_subplot_grid_position(
    *,
    element: int,
    num_cols: int,
) -> Tuple[int, int]:
    """Map a zero-based subplot index to its one-based (row, column) cell.

    Cells are numbered row by row, left to right.

    Parameters
    ----------
    element : int
        Zero-based index of the subplot.
    num_cols : int
        Number of columns in the grid.

    Returns
    -------
    position : Tuple[int, int]
        One-based `(row, column)` of the subplot.
    """
    zero_based_row, zero_based_col = divmod(element, num_cols)
    return zero_based_row + 1, zero_based_col + 1
98 changes: 53 additions & 45 deletions functime/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,20 @@
import polars as pl
from plotly.subplots import make_subplots

from functime._utils_plotting import (
add_traces,
get_chosen_entities,
get_num_points,
get_num_rows,
get_num_series,
make_figure,
)
from functime.base.metric import METRIC_TYPE
from functime.metrics import smape

if TYPE_CHECKING:
from typing import Optional, Union

COLOR_PALETTE = {"actual": "#B7B7B7", "forecast": "#1b57f1", "backtest": "#A76EF4"}
DEFAULT_LAST_N = 64

Expand All @@ -29,7 +38,7 @@ def plot_entities(
y : pl.DataFrame | pl.LazyFrame
Panel DataFrame of observed values.
**kwargs
Additional keyword arguments to pass to `plotly.graph_objects.Figure.update_layout`.
Additional keyword arguments to pass to `plotly.graph_objects.Figure.update_layout` or, equivalently, a `plotly.graph_objects.Layout` object.
Returns
-------
Expand Down Expand Up @@ -61,13 +70,14 @@ def plot_entities(
)


# TODO: if num_points is in (0, 1], then take that percentage of the points
def plot_panel(
y: Union[pl.DataFrame, pl.LazyFrame],
*,
n_series: int = 10,
seed: int | None = None,
n_cols: int = 2,
last_n: int = DEFAULT_LAST_N,
num_series: Optional[int] = None,
num_points: Optional[int] = None,
num_cols: Optional[int] = None,
seed: Optional[int] = None,
**kwargs,
):
"""Given panel DataFrames of observed values `y`,
Expand All @@ -77,68 +87,66 @@ def plot_panel(
----------
y : Union[pl.DataFrame, pl.LazyFrame]
Panel DataFrame of observed values.
n_series : int
num_series : Optional[int]
Number of entities / time-series to plot.
Defaults to 10.
seed : int | None
Random seed for sampling entities / time-series.
Defaults to None.
n_cols : int
Defaults to 10 or the number of entities in the data, if less.
num_points : Optional[int]
Plot the `num_points` most recent values of each series in `y`.
Defaults to 64 or the number of points in the shortest entity in the data, if less.
num_cols : Optional[int]
Number of columns to arrange subplots.
Defaults to 2.
last_n : int
Plot `last_n` most recent values in `y` and `y_pred`.
Defaults to 64.
seed : Optional[int]
Random seed for sampling entities / time-series.
Defaults to None.
**kwargs
Additional keyword arguments to pass to `plotly.graph_objects.Figure.update_layout` or, equivalently, a `plotly.graph_objects.Layout` object.
Returns
-------
figure : plotly.graph_objects.Figure
Plotly subplots.
"""
entity_col, time_col, target_col = y.columns[:3]

if isinstance(y, pl.DataFrame):
y = y.lazy()

entities = y.select(pl.col(entity_col).unique(maintain_order=True)).collect()
entity_col = y.columns[0]
entities_counts = y.select(pl.col(entity_col).value_counts())

entities_sample = entities.to_series().sample(n_series, seed=seed)
n_cols = num_cols or 2
n_series = get_num_series(entities_counts=entities_counts, num_series=num_series)
n_rows = get_num_rows(num_series=n_series, num_cols=n_cols)

n_points = get_num_points(entities_counts=entities_counts, num_points=num_points)

entities_chosen = get_chosen_entities(y=y, n_series=n_series, seed=seed)

# Get most recent observations
y = (
y.filter(pl.col(entity_col).is_in(entities_sample))
y.filter(pl.col(entity_col).is_in(entities_chosen))
.group_by(entity_col)
.tail(last_n)
.tail(n_points)
.collect()
)

# Organize subplots
n_rows = n_series // n_cols
row_idx = np.repeat(range(n_rows), n_cols)
fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=entities)
title = f"Line plot of the last {n_points} points of {n_series} series"

for i, entity_id in enumerate(entities):
ts = y.filter(pl.col(entity_col) == entity_id)
row = row_idx[i] + 1
col = i % n_cols + 1
# Plot actual
fig.add_trace(
go.Scatter(
x=ts.get_column(time_col),
y=ts.get_column(target_col),
name="Time-series",
legendgroup="Time-series",
line=dict(color=COLOR_PALETTE["forecast"]),
),
row=row,
col=col,
)
figure = make_figure(
n_rows=n_rows,
n_cols=n_cols,
entities_chosen=entities_chosen,
default_title=title,
**kwargs,
)

template = kwargs.pop("template", "plotly_white")
add_traces(
figure=figure,
y=y,
name="Time-series",
color="forecast",
n_cols=n_cols,
)

fig.update_layout(template=template, **kwargs)
fig = _remove_legend_duplicates(fig)
return fig
return figure


def plot_forecasts(
Expand Down

0 comments on commit 0848504

Please # to comment.