-
Notifications
You must be signed in to change notification settings - Fork 60
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Changed argument names refactor: extract another function refactor: complete plot_panel refactor refactor: rename plotting utils module refactor: add value boundary check refactor: add boundary checks
- Loading branch information
1 parent
4ddab23
commit 0848504
Showing
2 changed files
with
214 additions
and
45 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING | ||
|
||
import plotly.graph_objects as go | ||
import polars as pl | ||
from plotly.subplots import make_subplots | ||
|
||
if TYPE_CHECKING: | ||
from typing import Optional, Tuple | ||
|
||
COLOR_PALETTE = {"actual": "#B7B7B7", "forecast": "#1b57f1", "backtest": "#A76EF4"} | ||
|
||
|
||
def make_figure( | ||
*, | ||
n_rows: int, | ||
n_cols: int, | ||
entities_chosen: pl.Series, | ||
default_title: str, | ||
**kwargs, | ||
) -> go.Figure: | ||
height = kwargs.pop("height", n_rows * 200) | ||
template = kwargs.pop("template", "plotly_white") | ||
title = kwargs.pop("title", default_title) | ||
|
||
layout = go.Layout( | ||
template=template, | ||
height=height, | ||
title=title, | ||
**kwargs, | ||
) | ||
|
||
return make_subplots( | ||
figure=go.Figure(layout=layout), | ||
rows=n_rows, | ||
cols=n_cols, | ||
subplot_titles=entities_chosen, | ||
) | ||
|
||
|
||
def add_traces( | ||
*, | ||
figure: go.Figure, | ||
y: pl.DataFrame, | ||
name: str, | ||
color: str, | ||
n_cols: int, | ||
): | ||
entity_col, time_col, target_col = y.columns[:3] | ||
for i, (_, ts) in enumerate(y.groupby([entity_col])): | ||
row, col = get_subplot_grid_position(element=i, num_cols=n_cols) | ||
|
||
figure.add_trace( | ||
go.Scatter( | ||
x=ts.get_column(time_col), | ||
y=ts.get_column(target_col), | ||
name=name, | ||
legendgroup=name, | ||
line=dict(color=COLOR_PALETTE[color]), | ||
showlegend=False, | ||
), | ||
row=row, | ||
col=col, | ||
) | ||
|
||
|
||
def get_chosen_entities( | ||
*, | ||
y: pl.LazyFrame, | ||
n_series, | ||
seed: Optional[int], | ||
) -> pl.Series: | ||
entity_col = y.columns[0] | ||
|
||
entities = ( | ||
y.select(pl.col(entity_col).unique(maintain_order=True)).collect().to_series() | ||
) | ||
|
||
if seed is None: | ||
entities_chosen = entities.slice(0, n_series) | ||
else: | ||
entities_chosen = entities.sample(n_series, seed=seed) | ||
return entities_chosen | ||
|
||
|
||
def get_num_points( | ||
*, | ||
entities_counts: pl.LazyFrame, | ||
num_points: Optional[int], | ||
) -> int: | ||
if num_points is not None and num_points <= 0: | ||
raise ValueError("Number of points must be a positive integer") | ||
|
||
entity_col = entities_counts.columns[0] | ||
|
||
min_points = ( | ||
entities_counts.select(pl.col(entity_col).struct.field("count").min()) | ||
.collect() | ||
.item() | ||
) | ||
|
||
if num_points is None: | ||
return min(min_points, 64) | ||
else: | ||
if num_points > min_points: | ||
raise ValueError( | ||
f"Number of points ({num_points}) is less than minimum number of points ({min_points})" | ||
) | ||
return num_points | ||
|
||
|
||
def get_num_series( | ||
*, | ||
entities_counts: pl.LazyFrame, | ||
num_series: Optional[int], | ||
) -> int: | ||
entity_col = entities_counts.columns[0] | ||
|
||
num_entities = ( | ||
entities_counts.select(pl.col(entity_col).struct.field(entity_col).n_unique()) | ||
.collect() | ||
.item() | ||
) | ||
|
||
if num_series is None: | ||
return min(num_entities, 10) | ||
else: | ||
if num_series > num_entities: | ||
raise ValueError( | ||
f"Number of entities ({num_entities}) is less than number of series ({num_series})" | ||
) | ||
return num_series | ||
|
||
|
||
def get_num_rows( | ||
*, | ||
num_series: int, | ||
num_cols: int, | ||
) -> int: | ||
if num_cols < 1: | ||
raise ValueError("Number of columns must be a positive integer") | ||
|
||
if num_series < 1: | ||
raise ValueError("Number of series must be a positive integer") | ||
|
||
num_rows = num_series // num_cols | ||
if num_series % num_cols != 0: | ||
num_rows += 1 | ||
|
||
return num_rows | ||
|
||
|
||
def get_subplot_grid_position( | ||
*, | ||
element: int, | ||
num_cols: int, | ||
) -> Tuple[int, int]: | ||
row_index = element // num_cols + 1 | ||
col_index = element % num_cols + 1 | ||
return row_index, col_index |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters