From 3cea049ef9e9a3d1b012f8eb5268d2b78493bc1c Mon Sep 17 00:00:00 2001 From: baggiponte <57922983+baggiponte@users.noreply.github.com> Date: Sun, 9 Jun 2024 18:32:17 +0200 Subject: [PATCH] docs: add docstrings docs: add top level module docstring docs: update docstrings --- functime/_plotting.py | 280 +++++++++++++++++++++++++++++++++++++---- functime/plotting.py | 32 +++-- tests/test_plotting.py | 4 +- 3 files changed, 274 insertions(+), 42 deletions(-) diff --git a/functime/_plotting.py b/functime/_plotting.py index 0abdd539..c46cfd92 100644 --- a/functime/_plotting.py +++ b/functime/_plotting.py @@ -1,3 +1,5 @@ +"""TimeSeriesDisplay class to draw panel datasets plots.""" + from __future__ import annotations from typing import TYPE_CHECKING @@ -8,9 +10,7 @@ if TYPE_CHECKING: from typing import ( - Any, ClassVar, - Dict, Optional, Sequence, Tuple, @@ -49,13 +49,27 @@ def __init__( num_series: int, num_rows: int, default_title: str, - layout_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, ): """Initialize a time series display. - The initialisation defines a `plotly.graphic_objects.Figure` figure with the given + The initialisation defines a `plotly.graphic_objects.Figure` figure with subplots. + + Parameters + ---------- + entities : pl.Series | Sequence[str] + Entities to plot in the subplot grid. + num_cols : int + Number of columns in the subplot grid. + num_series : int + Number of series in the subplot grid. + num_rows : int + Number of rows in the subplot grid. + default_title : str + Default title for the figure. + kwargs + Additional keyword arguments to pass to `plotly.graph_objects.Layout` object. - Args: """ @@ -63,7 +77,6 @@ def __init__( self.num_series = num_series self.num_rows = num_rows self.entities = entities - kwargs = layout_kwargs or {} height = kwargs.pop("height", num_rows * 200) template = kwargs.pop("template", "plotly_white") @@ -94,6 +107,33 @@ def from_panel( default_title: str, **kwargs, ): + """Initialize a time series display from a Panel LazyFrame. + + The initialisation defines a `plotly.graphic_objects.Figure` figure with the given + number of columns and rows, and the entities in the data. + + Parameters + ---------- + y : pl.LazyFrame + Panel LazyFrame time series data. + num_cols : Optional[int] + Number of columns in the subplot grid. Defaults to 2. + num_series : Optional[int] + Number of series in the subplot grid. If `0`, plot all series. + If `None`, plot the smallest number between 10 and the total number of entities in the data. + Defaults to `None`. + seed : Optional[int] + Seed for the random sample of entities to plot. Defaults to `None`. + default_title : str + Default title for the figure. + **kwargs + Additional keyword arguments to pass to `plotly.graph_objects.Figure` object. + + Returns + ------- + self : Self + Instance of `TimeSeriesDisplay`. + """ n_cols = num_cols or 2 n_series = get_num_series( @@ -103,7 +143,7 @@ def from_panel( sample_entities = get_chosen_entities( y=y, - n_series=n_series, + num_series=n_series, seed=seed, ) @@ -123,11 +163,32 @@ def from_panel( def add_time_series( self: Self, + *, data: pl.LazyFrame, num_points: Optional[int] = None, - name: str = "Time-series", + name_on_hover: Optional[str] = None, **kwargs, ) -> Self: + """Add a time series to the subplot grid. + + Parameters + ---------- + data : pl.LazyFrame + Panel LazyFrame time series data. + num_points : Optional[int] + Number of data points to plot. + Defaults to 64 or the number of points in the shortest entity in the data. + If 0, plot all points. + name_on_hover : Optional[str] + Text that will be displayed on hover. Defaults to the name of the target column. + **kwargs + Additional keyword arguments to pass to `plotly.graph_objects.Line` object. + + Returns + ------- + self : Self + Instance of `TimeSeriesDisplay`. + """ entity_col = data.columns[0] if num_points == 0: # plot the whole series @@ -148,7 +209,7 @@ def add_time_series( self.figure = add_traces( figure=self.figure, y=y, - name=name, + name_on_hover=name_on_hover, num_cols=self.num_cols, **kwargs, ) @@ -160,21 +221,44 @@ def add_traces( *, figure: go.Figure, y: pl.DataFrame, - name: str, + name_on_hover: Optional[str] = None, num_cols: int, **kwargs, ) -> go.Figure: + """Add scatterplot traces to a `Figure` instance. + + The function needs to know the number of columns in the subplot grid to + place a trace in the correct position. + + Parameters + ---------- + figure : go.Figure + Plotly figure to add traces to. + y : pl.DataFrame + Panel DataFrame. + name_on_hover : Optional[str] + Text that will be displayed on hover. Defaults to the name of the target column, in title case. + num_cols : int + Number of columns in the subplot grid. + **kwargs + Additional keyword arguments to pass to `plotly.graph_objects.Line` object. + + Returns + ------- + figure : go.Figure + Updated Plotly figure. + """ entity_col, time_col, target_col = y.columns[:3] for i, (_, ts) in enumerate(y.groupby([entity_col])): - row, col = get_subplot_grid_position(element=i, num_cols=num_cols) + col, row = get_subplot_grid_position(element=i, num_cols=num_cols) figure.add_trace( go.Scatter( x=ts.get_column(time_col), y=ts.get_column(target_col), - name=name, - legendgroup=name, + name=name_on_hover or target_col.title(), + legendgroup=name_on_hover, line=kwargs, showlegend=False, ), @@ -188,19 +272,58 @@ def add_traces( def get_chosen_entities( *, y: pl.LazyFrame, - n_series: int, + num_series: int, seed: Optional[int] = None, ) -> pl.Series: + """Sample entities to plot in a subplot grid, given the data. + + The function checks whether the `n_series` is bigger than the total number of entities in the data. If so, it raises an error. + + If `seed` is `None`, it returns the first `n_series` entities in the data. Alternatively, it returns a random sample. + + Parameters + ---------- + y : pl.LazyFrame + Panel LazyFrame time series data. + n_series : int + Number of series to sample. + seed : Optional[int] + Seed for the random sample. Defaults to `None`, i.e. no sampling. + + Returns + ------- + entities : pl.Series + Series of sampled entities. + + Raises + ------ + ValueError + If `n_series` is bigger than the total number of entities in the data. + + Example + ------- + >>> get_chosen_entities(y=pl.DataFrame({"entity": ["a", "b", "c", "d", "e"], "value": [1, 2, 3, 4, 5]}), n_series=2) + pl.Series(['a', 'b'], name='entity') + >>> get_chosen_entities(y=pl.DataFrame({"entity": ["a", "b", "c", "d", "e"], "value": [1, 2, 3, 4, 5]}), n_series=0) + pl.Series(['a', 'b', 'c', 'd', 'e'], name='entity') + """ entity_col = y.columns[0] + total_entities = y.select(pl.col(entity_col).n_unique()).collect().item() + + if num_series > total_entities: + raise ValueError( + f"Number of series ({num_series}) is greater than the total number of entities ({total_entities})" + ) + entities = ( y.select(pl.col(entity_col).unique(maintain_order=True)).collect().to_series() ) if seed is None: - return entities.slice(0, n_series) + return entities.slice(0, num_series) else: - return entities.sample(n_series, seed=seed) + return entities.sample(num_series, seed=seed) def get_num_points( @@ -208,6 +331,36 @@ def get_num_points( y: pl.LazyFrame, num_points: Optional[int], ) -> int: + """Get the number of data points to plot in a subplot grid, given the data. + + The function checks whether the `num_points` is bigger than the shortest entity in the data. If so, it raises an error. + + If `num_points` is `None`, it returns the smallest number between 64 and the shortest entity in the data. + + Parameters + ---------- + y : pl.LazyFrame + Panel LazyFrame time series data. + num_points : Optional[int] + Number of data points to plot. If `None`, returns the smallest number between 64 and the shortest entity in the data. + + Returns + ------- + int + Number of data points to plot. + + Raises + ------ + ValueError + If `num_points` is smaller than 0 or bigger than the shortest entity in the data. + + Example + ------- + >>> get_num_points(y=pl.DataFrame({"entity": ["a", "b", "c", "d", "e"], "value": [1, 2, 3, 4, 5]}), num_points=2) + 2 + >>> get_num_points(y=pl.DataFrame({"entity": ["a", "b", "c", "d", "e"], "value": [1, 2, 3, 4, 5]}), num_points=0) + 5 + """ if num_points is not None and num_points < 0: raise ValueError("Number of points must be 0 or greater") @@ -231,6 +384,36 @@ def get_num_series( y: pl.LazyFrame, num_series: Optional[int], ) -> int: + """Get the number of series to plot in a subplot grid, given the data. + + The function checks whether the `num_series` is bigger than the total number of entities in the data. If so, it raises an error. + + If `num_series` is `None`, it returns the smallest number between 10 and the total number of entities in the data. + + Parameters + ---------- + y : pl.LazyFrame + Panel LazyFrame time series data. + num_series : Optional[int] + Number of series to sample. If `None`, returns the smallest number between 10 and the total number of entities in the data. + + Returns + ------- + int + Number of series to draw in the subplot grid. + + Raises + ------ + ValueError + If `num_series` is 0 and `num_cols` is smaller than 1, or if `num_series` is bigger than the total number of entities in the data. + + Example + ------- + >>> get_num_series(y=pl.DataFrame({"entity": ["a", "b", "c", "d", "e"], "value": [1, 2, 3, 4, 5]}), num_series=2) + 2 + >>> get_num_series(y=pl.DataFrame({"entity": ["a", "b", "c", "d", "e"], "value": [1, 2, 3, 4, 5]}), num_series=0) + 5 + """ if num_series is not None and num_series < 0: raise ValueError("Number of series must be 0 or greater") @@ -252,16 +435,41 @@ def get_num_rows( num_series: int, num_cols: int, ) -> int: + """Get the number of rows in a subplot grid, given the number of series and number of columns. + + Parameters + ---------- + num_series : int + Number of series in the subplot grid. + num_cols : int + Number of columns in the subplot grid. + + Returns + ------- + int + Number of rows in the subplot grid. + + Raises + ------ + ValueError + If `num_series` is 0 and `num_cols` is smaller than 1. + + Example + ------- + >>> get_num_rows(num_series=10, num_cols=2) + 5 + >>> get_num_rows(num_series=5, num_cols=2) + 3 + """ if num_cols < 1: raise ValueError("Number of columns must be a positive integer") if num_series < 1: raise ValueError("Number of series must be a positive integer") - num_rows = num_series // num_cols - if num_series % num_cols != 0: - num_rows += 1 - + num_rows, remainder = divmod(num_series, num_cols) + if remainder != 0: + return num_rows + 1 return num_rows @@ -270,6 +478,32 @@ def get_subplot_grid_position( element: int, num_cols: int, ) -> Tuple[int, int]: - row_index = element // num_cols + 1 - col_index = element % num_cols + 1 - return row_index, col_index + """Get the row and column index of the subplot at the given element index. + + Need to add 1 because the grid indexes in a plotly subplot are 1-based. + + Parameters + ---------- + element : int + Element index in the subplot grid. + num_cols : int + Number of columns in the subplot grid. + + Returns + ------- + Tuple[int, int] + Row and column index of the subplot. + + Example + ------- + >>> get_subplot_grid_position(element=0, num_cols=2) + (1, 1) + >>> get_subplot_grid_position(element=1, num_cols=2) + (1, 2) + >>> get_subplot_grid_position(element=2, num_cols=2) + (2, 1) + >>> get_subplot_grid_position(element=3, num_cols=2) + (2, 2) + """ + col_index, row_index = divmod(element, num_cols) + return row_index + 1, col_index + 1 diff --git a/functime/plotting.py b/functime/plotting.py index 3dfb4848..f21eae67 100644 --- a/functime/plotting.py +++ b/functime/plotting.py @@ -31,8 +31,7 @@ def plot_entities( y : pl.DataFrame | pl.LazyFrame Panel DataFrame of observed values. **kwargs - Additional keyword arguments to pass to `plotly.graph_objects.Figure.update_layout` or, - equivalently, a `plotly.graph_objects.Layout` object. + Additional keyword arguments to pass to a `plotly.graph_objects.Layout` object. Returns ------- @@ -50,18 +49,20 @@ def plot_entities( title = kwargs.pop("title", "Entities counts") template = kwargs.pop("template", "plotly_white") + layout = go.Layout( + height=height, + title=title, + template=template, + **kwargs, + ) + return go.Figure( data=go.Bar( x=entity_counts.get_column("len"), y=entity_counts.get_column(entity_col), orientation="h", ), - layout=go.Layout( - height=height, - title=title, - template=template, - **kwargs, - ), + layout=layout, ) @@ -84,26 +85,24 @@ def plot_panel( y : Union[pl.DataFrame, pl.LazyFrame] Panel DataFrame of observed values. num_series : Optional[int] - Number of entities / time-series to plot. + Number of entities / time-series to plot. If 0, plot all entities. Defaults to 10 or the number of entities in the data, if less. num_points : Optional[int] - Plot `last_n` most recent values in `y` and `y_pred`. + Plot `last_n` most recent values in `y`. If 0, plot all points. Defaults to 64 or the number of points in the shortest entity in the data. num_cols : Optional[int] - Number of columns to arrange subplots. - Defaults to 2. + Number of columns to arrange subplots. Defaults to 2. seed : Optional[int] - Random seed for sampling entities / time-series. - Defaults to None. + Random seed for sampling entities / time-series. Defaults to `None`. layout_kwargs - Additional keyword arguments to pass to `plotly.graph_objects.Figure.update_layout` or, equivalently, a `plotly.graph_objects.Layout` object. + Additional keyword arguments to pass to a `plotly.graph_objects.Layout` object. line_kwargs Additional keyword arguments to pass to a `plotly.graph_objects.Line` object. Returns ------- figure : plotly.graph_objects.Figure - Plotly subplots. + Plotly instance of `Figure` with all the subplots. """ if isinstance(y, pl.DataFrame): @@ -121,7 +120,6 @@ def plot_panel( drawer.add_time_series( data=y, num_points=num_points, - name="Time-series", **line_kwargs or {"color": drawer.DEFAULT_PALETTE["primary"]}, ) diff --git a/tests/test_plotting.py b/tests/test_plotting.py index 90cbedfc..aee32a6c 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -23,7 +23,7 @@ def mock_dataframe(): def test_get_chosen_entities_not_random(mock_dataframe): - actual = get_chosen_entities(y=mock_dataframe, n_series=3, seed=None).to_list() + actual = get_chosen_entities(y=mock_dataframe, num_series=3, seed=None).to_list() expected = ["A", "B", "C"] assert actual == expected @@ -40,7 +40,7 @@ def test_get_chosen_entities_random(mock_dataframe, n_series, seed): actual = get_chosen_entities( y=mock_dataframe, - n_series=n_series, + num_series=n_series, seed=seed, ).to_list()