From 30f519f885ed22926602a372d584b3139fcfda52 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 14 Jun 2023 20:50:48 -0400 Subject: [PATCH 01/36] fixes --- mesmerize_viz/_cnmf.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/mesmerize_viz/_cnmf.py b/mesmerize_viz/_cnmf.py index 5f009d4..1d46e53 100644 --- a/mesmerize_viz/_cnmf.py +++ b/mesmerize_viz/_cnmf.py @@ -526,7 +526,6 @@ def _change_data_gridplot( elif data_option == "heatmap": current_graphic = subplot.add_heatmap( data_array, - colors=component_colors, name="components", **graphic_kwargs ) @@ -558,9 +557,7 @@ def _change_data_gridplot( subplot.camera.maintain_aspect = False if len(self.linear_selectors) > 0: - self._synchronizer = Synchronizer( - *self.linear_selectors, key_bind=None - ) + self._synchronizer = Synchronizer(*self.linear_selectors) for ls in self.linear_selectors: ls.selection.add_event_handler(self.set_frame_index) From 80d5f2bb8829b825be0788ea4ae33224617cdc0c Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Sat, 17 Jun 2023 02:16:20 -0400 Subject: [PATCH 02/36] aspect for temporal data, detect if in jupyter, events --- mesmerize_viz/_cnmf.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/mesmerize_viz/_cnmf.py b/mesmerize_viz/_cnmf.py index 1d46e53..d076cf1 100644 --- a/mesmerize_viz/_cnmf.py +++ b/mesmerize_viz/_cnmf.py @@ -3,6 +3,7 @@ import math import itertools from warnings import warn +from itertools import product import ipywidgets import numpy as np @@ -514,6 +515,9 @@ def _change_data_gridplot( # otherwise the plot has nothing in it which causes issues subplot.add_line(np.random.rand(data_array.shape[1]), colors=(0, 0, 0, 0), name="pseudo-line") + # scale according to temporal dims + subplot.camera.maintain_aspect = False + elif data_option == "temporal-stack": current_graphic = subplot.add_line_stack( data_array, @@ -523,6 +527,9 @@ def _change_data_gridplot( ) self.temporal_stack_graphics.append(current_graphic) + # scale according to temporal dims + subplot.camera.maintain_aspect = False + elif data_option == "heatmap": current_graphic = subplot.add_heatmap( data_array, @@ -531,6 +538,9 @@ def _change_data_gridplot( ) self.heatmap_graphics.append(current_graphic) + # scale according to temporal dims + subplot.camera.maintain_aspect = False + else: img_graphic = subplot.add_image( data_array[self._current_frame_index], @@ -600,6 +610,9 @@ def _connect_events(self): for temporal_graphic in self.temporal_graphics: contour_graphic.link("colors", target=temporal_graphic, feature="present", new_data=True) + for cg, tsg in product(self.contour_graphics, self.temporal_stack_graphics): + cg.link("colors", target=contour_graphic, feature="colors", new_data="w", bidirectional=True) + def set_frame_index(self, ev): # 0 because this will return the same number repeated * n_components index = ev.pick_info["selected_index"][0] @@ -804,12 +817,19 @@ def show(self): self.show_all_checkbox = Checkbox(value=True, description="Show all components") + gridplots_widget = [gp.show() for gp in self.gridplots] + + if "Jupyter" in self.gridplots[0].canvas.__class__.__name__: + vbox_elements = gridplots_widget + else: + vbox_elements = list() + widget = VBox( [ HBox([self.datagrid, self.params_text_area]), self.show_all_checkbox, HBox([self._gridplot_wrapper.component_slider, self._gridplot_wrapper.component_int_box]), - VBox([gp.show() for gp in self.gridplots]) + VBox(vbox_elements) ] ) From e69f61546a56147090b133f7445dc5da679bb388 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 21 Jun 2023 03:35:10 -0400 Subject: [PATCH 03/36] split up cnmf into modules, start eval widget --- mesmerize_viz/_cnmf.py | 987 -------------------------- mesmerize_viz/_cnmf/__init__.py | 0 mesmerize_viz/_cnmf/_eval.py | 71 ++ mesmerize_viz/_cnmf/_extensions.py | 90 +++ mesmerize_viz/_cnmf/_time_array.py | 137 ++++ mesmerize_viz/_cnmf/_viz_container.py | 287 ++++++++ mesmerize_viz/_cnmf/_wrapper.py | 478 +++++++++++++ 7 files changed, 1063 insertions(+), 987 deletions(-) delete mode 100644 mesmerize_viz/_cnmf.py create mode 100644 mesmerize_viz/_cnmf/__init__.py create mode 100644 mesmerize_viz/_cnmf/_eval.py create mode 100644 mesmerize_viz/_cnmf/_extensions.py create mode 100644 mesmerize_viz/_cnmf/_time_array.py create mode 100644 mesmerize_viz/_cnmf/_viz_container.py create mode 100644 mesmerize_viz/_cnmf/_wrapper.py diff --git a/mesmerize_viz/_cnmf.py b/mesmerize_viz/_cnmf.py deleted file mode 100644 index d076cf1..0000000 --- a/mesmerize_viz/_cnmf.py +++ /dev/null @@ -1,987 +0,0 @@ -from typing import * -from functools import partial -import math -import itertools -from warnings import warn -from itertools import product - -import ipywidgets -import numpy as np -import pandas as pd - -from mesmerize_core.arrays._base import LazyArray -from mesmerize_core.utils import quick_min_max - -from fastplotlib import ImageWidget, GridPlot, graphics -from fastplotlib.graphics.selectors import LinearSelector, Synchronizer -from fastplotlib.utils import calculate_gridshape - -from ipydatagrid import DataGrid -from ipywidgets import Textarea, VBox, HBox, Layout, Checkbox, IntSlider, BoundedIntText, jslink - -from ._utils import ZeroArray, format_params - - -# basic data options -VALID_DATA_OPTIONS = [ - "contours", - "empty" -] - -IMAGE_OPTIONS = [ - "input", - "rcm", - "rcb", - "residuals", - "corr", - "pnr", -] - -VALID_DATA_OPTIONS += IMAGE_OPTIONS - - -TEMPORAL_OPTIONS = [ - "temporal", - "temporal-stack", - "heatmap", -] - -VALID_DATA_OPTIONS += TEMPORAL_OPTIONS - -# RCM and RCB projections -rcm_rcb_proj_options = list() - -for option in ["rcm", "rcb"]: - for proj in ["mean", "min", "max", "std"]: - rcm_rcb_proj_options.append(f"{option}-{proj}") - -VALID_DATA_OPTIONS += rcm_rcb_proj_options -IMAGE_OPTIONS += rcm_rcb_proj_options - - -projs = [ - "mean", - "max", - "std", -] - -IMAGE_OPTIONS += projs - -VALID_DATA_OPTIONS += projs - - -class ExtensionCallWrapper: - def __init__(self, extension_func: callable, kwargs: dict = None, attr: str = None): - """ - Basically like ``functools.partial`` but supports kwargs. - - Parameters - ---------- - extension_func: callable - extension function reference - - kwargs: dict - kwargs to pass to the extension function when it is called - - attr: str, optional - return an attribute of the callable's output instead of the return value of the callable. - Example: if using rcm, can set ``attr="max_image"`` to return the max proj of the RCM. - """ - - if kwargs is None: - self.kwargs = dict() - else: - self.kwargs = kwargs - - self.func = extension_func - self.attr = attr - - def __call__(self, *args, **kwargs): - rval = self.func(**self.kwargs) - - if self.attr is not None: - return getattr(rval, self.attr) - - return rval - - -def get_cnmf_data_mapping(series: pd.Series, data_kwargs: dict = None, other_data_loaders: dict = None) -> dict: - """ - Returns dict that maps data option str to a callable that can return the corresponding data array. - - For example, ``{"input": series.get_input_movie}`` maps "input" -> series.get_input_movie - - Parameters - ---------- - series: pd.Series - row/item to get mapping from - - data_kwargs: dict, optional - optional kwargs for each of the extension functions - - other_data_loaders: dict - {"data_option": callable}, example {"behavior": LazyVideo} - - Returns - ------- - dict - {data label: callable} - """ - if data_kwargs is None: - data_kwargs = dict() - - if other_data_loaders is None: - other_data_loaders = dict() - - default_extension_kwargs = {k: dict() for k in VALID_DATA_OPTIONS + list(other_data_loaders.keys())} - - default_extension_kwargs["contours"] = {"swap_dim": False} - - ext_kwargs = { - **default_extension_kwargs, - **data_kwargs - } - - projections = {k: partial(series.caiman.get_projection, k) for k in projs} - - other_data_loaders_mapping = dict() - - # make ExtensionCallWrapers for other data loaders - for option in list(other_data_loaders.keys()): - other_data_loaders_mapping[option] = ExtensionCallWrapper(other_data_loaders[option], ext_kwargs[option]) - - rcm_rcb_projs = dict() - for proj in ["mean", "min", "max", "std"]: - rcm_rcb_projs[f"rcm-{proj}"] = ExtensionCallWrapper( - series.cnmf.get_rcm, - ext_kwargs["rcm"], - attr=f"{proj}_image" - ) - - temporal_mappings = { - k: ExtensionCallWrapper(series.cnmf.get_temporal, ext_kwargs[k]) for k in TEMPORAL_OPTIONS - } - - m = { - "input": ExtensionCallWrapper(series.caiman.get_input_movie, ext_kwargs["input"]), - "rcm": ExtensionCallWrapper(series.cnmf.get_rcm, ext_kwargs["rcm"]), - "rcb": ExtensionCallWrapper(series.cnmf.get_rcb, ext_kwargs["rcb"]), - "residuals": ExtensionCallWrapper(series.cnmf.get_residuals, ext_kwargs["residuals"]), - "corr": ExtensionCallWrapper(series.caiman.get_corr_image, ext_kwargs["corr"]), - "contours": ExtensionCallWrapper(series.cnmf.get_contours, ext_kwargs["contours"]), - "empty": ZeroArray, - **temporal_mappings, - **projections, - **rcm_rcb_projs, - **other_data_loaders_mapping - } - - return m - - -# TODO: maybe this can be used so that ImageWidget can be used for both behavior and calcium -# TODO: but then we need an option to set window_funcs separately for each subplot -class TimeArray(LazyArray): - """ - Wrapper for array-like that takes units of millisecond for slicing - """ - def __init__(self, array: Union[np.ndarray, LazyArray], timestamps = None, framerate = None): - """ - Arrays which can be sliced using timepoints in units of millisecond. - Supports slicing with start and stop timepoints, does not support slice steps. - - i.e. You can do this: time_array[30], time_array[30:], time_array[:50], time_array[30:50]. - You cannot do this: time_array[::10], time_array[0::10], time_array[0:50:10] - - Parameters - ---------- - array: array-like - data array, must have shape attribute and first dimension must be frame index - - timestamps: np.ndarray, 1 dimensional - timestamps in units of millisecond, you must provide either timestamps or framerate. - MUST be in order such that t_(n +1) > t_n for all n. - - framerate: float - framerate, in units of Hz (per second). You must provide either timestamps or framerate - """ - self._array = array - - if timestamps is None and framerate is None: - raise ValueError("Must provide timestamps or framerate") - - if timestamps is None: - # total duration in milliseconds = n_frames / framerate - n_frames = self.shape[0] - stop_time_ms = (n_frames / framerate) * 1000 - timestamps = np.linspace( - start=0, - stop=stop_time_ms, - num=n_frames, - endpoint=False - ) - - if timestamps.size != self._array.shape[0]: - raise ValueError("timestamps.size != array.shape[0]") - - self.timestamps = timestamps - - def _get_closest_index(self, timepoint: float): - """ - from: https://stackoverflow.com/a/26026189/4697983 - - This is very fast, 10 microseconds even for a - - Parameters - ---------- - timepoint: float - timepoint in milliseconds - - Returns - ------- - int - index for the closest timestamp, which also corresponds to the frame index of the data array - """ - value = timepoint - array = self.timestamps - - idx = np.searchsorted(array, value, side="left") - if idx > 0 and (idx == len(array) or math.fabs(value - array[idx - 1]) < math.fabs(value - array[idx])): - return idx - 1 - else: - return idx - - # override __getitem__ since it will not work with LazyArray base implementation since: - # 1. base implementation requires the slice indices to be less than shape[0] - # 2. base implementation does not consider slicing with float values - def __getitem__(self, indices: Union[slice, int, float]) -> np.ndarray: - if isinstance(indices, slice): - if indices.step is not None: - raise IndexError( - "TimeArray slicing does not support step, only start and stop. See docstring." - ) - - if indices.start is None: - start = 0 - else: - start = self._get_closest_index(indices.start) - - if indices.stop is None: - stop = self.n_frames - else: - stop = self._get_closest_index(indices.stop) - - s = slice(start, stop) - return self._array[s] - - # single index - index = self._get_closest_index(indices) - return self._array[index] - - def _compute_at_indices(self, indices: Union[int, slice]) -> np.ndarray: - """not implemented here""" - pass - - @property - def n_frames(self) -> int: - return self.shape[0] - - @property - def shape(self) -> Tuple[int, int, int]: - return self._array.shape - - @property - def dtype(self) -> str: - return str(self._array.dtype) - - @property - def min(self) -> float: - if isinstance(self._array, LazyArray): - return self._array.min - else: - return quick_min_max(self._array)[0] - - @property - def max(self) -> float: - if isinstance(self._array, LazyArray): - return self._array.max - else: - return quick_min_max(self._array)[1] - - -class GridPlotWrapper: - """Wraps GridPlot in a way that allows updating the data""" - - def __init__( - self, - data: Union[List[str], List[List[str]]], - data_mapping: Dict[str, ExtensionCallWrapper], - reset_timepoint_on_change: bool = False, - data_graphic_kwargs: dict = None, - # slider_ipywidget: ipywidgets.IntSlider = None, - gridplot_kwargs: dict = None, - cmap: str = "gnuplot2", - component_colors: str = "random" - ): - """ - Visualize motion correction output. - - Parameters - ---------- - data: list of str or list of list of str - list of data to plot, examples: ["input", "temporal-stack"], [["temporal"], ["rcm", "rcb"]] - - data_mapping: dict - maps {"data_option": callable} - - reset_timepoint_on_change: bool, default False - reset the timepoint in the ImageWidget when changing items/rows - - data_graphic_kwargs: dict - passed add_ for corresponding graphic - - slider_ipywidget: ipywidgets.IntSlider - time slider from ImageWidget - - gridplot_kwargs: dict, optional - kwargs passed to GridPlot - - """ - - self._data = data - - if data_graphic_kwargs is None: - data_graphic_kwargs = dict() - - self.data_graphic_kwargs = data_graphic_kwargs - - if gridplot_kwargs is None: - gridplot_kwargs = dict() - - self._cmap = cmap - - self.component_colors = component_colors - - # self._slider_ipywidget = slider_ipywidget - - self.reset_timepoint_on_change = reset_timepoint_on_change - - self.gridplots: List[GridPlot] = list() - - self.component_slider = IntSlider(min=0, max=1, value=0, step=1, description="component index:") - self.component_int_box = BoundedIntText(min=0, max=1, value=0, step=1) - for trait in ["value", "max"]: - jslink((self.component_slider, trait), (self.component_int_box, trait)) - - self.component_int_box.observe(self.set_component_index, "value") - - # gridplot for each sublist - for sub_data in self._data: - _gridplot_kwargs = {"shape": calculate_gridshape(len(sub_data))} - _gridplot_kwargs.update(gridplot_kwargs) - self.gridplots.append(GridPlot(**_gridplot_kwargs)) - - self.temporal_graphics: List[graphics.LineCollection] = list() - self.temporal_stack_graphics: List[graphics.LineStack] = list() - self.heatmap_graphics: List[graphics.HeatmapGraphic] = list() - self.image_graphics: List[graphics.ImageGraphic] = list() - self.contour_graphics: List[graphics.LineCollection] = list() - - self._managed_graphics: List[list] = [ - self.temporal_graphics, - self.temporal_stack_graphics, - self.image_graphics, - self.contour_graphics - ] - - # to store only image data in a 1:1 mapping to the graphics list - self.image_graphic_arrays: List[np.ndarray] = list() - - self.linear_selectors: List[LinearSelector] = list() - - self._current_frame_index: int = 0 - - self.change_data(data_mapping) - - def set_component_index(self, change): - index = change["new"] - - for g in self.contour_graphics: - g._set_feature(feature="colors", new_data="w", indices=index) - - def _parse_data(self, data_options, data_mapping) -> List[List[np.ndarray]]: - """ - Returns nested list of array-like - """ - data_arrays = list() - - for d in data_options: - if isinstance(d, list): - data_arrays.append(self._parse_data(d, data_mapping)) - - elif d == "empty": - data_arrays.append(None) - - else: - func = data_mapping[d] - a = func() - data_arrays.append(a) - - return data_arrays - - @property - def cmap(self) -> str: - return self._cmap - - @cmap.setter - def cmap(self, cmap: str): - for g in self.image_graphics: - g.cmap = cmap - - # @property - # def component_colors(self) -> Any: - # pass - # - # @component_colors.setter - # def component_colors(self, colors: Any): - # for collection in self.contour_graphics: - # for g in collection.graphics: - # - - def change_data(self, data_mapping: Dict[str, callable]): - for l in self._managed_graphics: - l.clear() - - self.image_graphic_arrays.clear() - - # clear existing subplots - for gp in self.gridplots: - gp.clear() - - # new data arrays - data_arrays = self._parse_data(data_options=self._data, data_mapping=data_mapping) - - # rval is (contours, centeres of masses) - contours = data_mapping["contours"]()[0] - - if self.component_colors == "random": - n_components = len(contours) - component_colors = np.random.rand(n_components, 4).astype(np.float32) - component_colors[:, -1] = 1 - else: - component_colors = self.component_colors - - self.component_slider.value = 0 - self.component_slider.max = len(contours) - - # change data for all gridplots - for sub_data, sub_data_arrays, gridplot in zip(self._data, data_arrays, self.gridplots): - self._change_data_gridplot(sub_data, sub_data_arrays, gridplot, contours, component_colors) - - # connect events - self._connect_events() - - def _change_data_gridplot( - self, - data: List[str], - data_arrays: List[np.ndarray], - gridplot: GridPlot, - contours, - component_colors - ): - - if self.reset_timepoint_on_change: - self._current_frame_index = 0 - - for data_option, data_array, subplot in zip(data, data_arrays, gridplot): - if data_option in self.data_graphic_kwargs.keys(): - graphic_kwargs = self.data_graphic_kwargs[data_option] - else: - graphic_kwargs = dict() - # skip - if data_option == "empty": - continue - - elif data_option == "temporal": - current_graphic = subplot.add_line_collection( - data_array, - colors=component_colors, - name="components", - **graphic_kwargs - ) - current_graphic[:].present.add_event_handler(subplot.auto_scale) - self.temporal_graphics.append(current_graphic) - - # otherwise the plot has nothing in it which causes issues - subplot.add_line(np.random.rand(data_array.shape[1]), colors=(0, 0, 0, 0), name="pseudo-line") - - # scale according to temporal dims - subplot.camera.maintain_aspect = False - - elif data_option == "temporal-stack": - current_graphic = subplot.add_line_stack( - data_array, - colors=component_colors, - name="components", - **graphic_kwargs - ) - self.temporal_stack_graphics.append(current_graphic) - - # scale according to temporal dims - subplot.camera.maintain_aspect = False - - elif data_option == "heatmap": - current_graphic = subplot.add_heatmap( - data_array, - name="components", - **graphic_kwargs - ) - self.heatmap_graphics.append(current_graphic) - - # scale according to temporal dims - subplot.camera.maintain_aspect = False - - else: - img_graphic = subplot.add_image( - data_array[self._current_frame_index], - cmap=self.cmap, - name="image", - **graphic_kwargs - ) - - self.image_graphics.append(img_graphic) - self.image_graphic_arrays.append(data_array) - - contour_graphic = subplot.add_line_collection( - contours, - colors=component_colors, - name="contours" - ) - - self.contour_graphics.append(contour_graphic) - - subplot.name = data_option - - if data_option in TEMPORAL_OPTIONS: - self.linear_selectors.append(current_graphic.add_linear_selector()) - subplot.camera.maintain_aspect = False - - if len(self.linear_selectors) > 0: - self._synchronizer = Synchronizer(*self.linear_selectors) - - for ls in self.linear_selectors: - ls.selection.add_event_handler(self.set_frame_index) - - def _euclidean(self, source, target, event, new_data): - """maps click events to contour""" - # calculate coms of line collection - indices = np.array(event.pick_info["index"]) - - coms = list() - - for contour in target.graphics: - coors = contour.data()[~np.isnan(contour.data()).any(axis=1)] - com = coors.mean(axis=0) - coms.append(com) - - # euclidean distance to find closest index of com - indices = np.append(indices, [0]) - - ix = int(np.linalg.norm((coms - indices), axis=1).argsort()[0]) - - target._set_feature(feature="colors", new_data=new_data, indices=ix) - - self.component_int_box.value = ix - - return None - - def _connect_events(self): - for image_graphic, contour_graphic in zip(self.image_graphics, self.contour_graphics): - image_graphic.link( - "click", - target=contour_graphic, - feature="colors", - new_data="w", - callback=self._euclidean - ) - - contour_graphic.link("colors", target=contour_graphic, feature="thickness", new_data=5) - - for temporal_graphic in self.temporal_graphics: - contour_graphic.link("colors", target=temporal_graphic, feature="present", new_data=True) - - for cg, tsg in product(self.contour_graphics, self.temporal_stack_graphics): - cg.link("colors", target=contour_graphic, feature="colors", new_data="w", bidirectional=True) - - def set_frame_index(self, ev): - # 0 because this will return the same number repeated * n_components - index = ev.pick_info["selected_index"][0] - for image_graphic, full_array in zip(self.image_graphics, self.image_graphic_arrays): - # txy data - if full_array.ndim > 2: - image_graphic.data = full_array[index] - - self._current_frame_index = index - - -# TODO: This use a GridPlot that's manually managed because the timescales of calcium ad behavior won't match -class CNMFVizContainer: - """Widget that contains the DataGrid, params text box fastplotlib GridPlot, etc""" - - def __init__( - self, - dataframe: pd.DataFrame, - data: List[str] = None, - start_index: int = 0, - reset_timepoint_on_change: bool = False, - data_graphic_kwargs: dict = None, - gridplot_kwargs: dict = None, - cmap: str = "gnuplot2", - component_colors: str = "random", - calcium_framerate: float = None, - other_data_loaders: Dict[str, callable] = None, - data_kwargs: dict = None, - data_grid_kwargs: dict = None, - ): - """ - Visualize CNMF output and other data columns such as behavior video (optional) - - Parameters - ---------- - dataframe: pd.DataFrame - - data: list of str - data options, such as "input", "temporal", "contours", etc. - - start_index - - reset_timepoint_on_change - - calcium_framerate - - other_data_loaders: Dict[str, callable] - if loading non-calcium related data arrays, provide dict of callables for opening them. - Example, if you provide ``data = ["contours", "temporal", "behavior"]``, and the "behavior" - column contains videos, you could provide `other_data_loads = {"behavior": LazyVideo} - - data_kwargs: dict - kwargs passed to corresponding extension function to load data. - example: ``{"temporal": {"component_ixs": "good"}}`` - - gridplot_kwargs: List[dict] - kwargs passed to GridPlot - - data_grid_kwargs - """ - - if data is None: - data = [["temporal"], ["input", "rcm", "rcb", "residuals"]] - - if other_data_loaders is None: - other_data_loaders = dict() - - # simple list of str, single gridplot - if all(isinstance(option, str) for option in data): - data = [data] - - if not all(isinstance(option, list) for option in data): - raise TypeError( - "Must pass list of str or nested list of str" - ) - - # make sure data options are valid - for d in list(itertools.chain(*data)): - if (d not in VALID_DATA_OPTIONS) and (d not in dataframe.columns): - raise ValueError( - f"`data` options are: {VALID_DATA_OPTIONS} or a DataFrame column name: {dataframe.columns}\n" - f"You have passed: {d}" - ) - - if d in dataframe.columns: - if d not in other_data_loaders.keys(): - raise ValueError( - f"You have provided the non-CNMF related data option: {d}.\n" - f"If you provide a non-cnmf related data option you must also provide a " - f"data loader callable for it to `other_data_loaders`" - ) - - self._other_data_loaders = other_data_loaders - - if data_grid_kwargs is None: - data_grid_kwargs = dict() - - self._dataframe = dataframe - - default_widths = { - "algo": 50, - 'item_name': 200, - 'input_movie_path': 120, - 'algo_duration': 80, - 'comments': 120, - 'uuid': 60 - } - - columns = dataframe.columns - # these add clutter - hide_columns = [ - "params", - "outputs", - "added_time", - "ran_time", - - ] - - df_show = self._dataframe[[c for c in columns if c not in hide_columns]] - - self.datagrid = DataGrid( - df_show, # show only a subset - selection_mode="cell", - layout={"height": "250px", "width": "750px"}, - base_row_size=24, - index_name="index", - column_widths=default_widths, - **data_grid_kwargs - ) - - self.params_text_area = Textarea() - self.params_text_area.layout = Layout( - height="250px", - max_height="250px", - width="360px", - max_width="500px" - ) - - # data options is private since this can't be changed once an image widget has been made - self._data = data - - if data_kwargs is None: - data_kwargs = dict() - - self.data_kwargs = data_kwargs - - self.current_row: int = start_index - - self._make_gridplot( - start_index=start_index, - reset_timepoint_on_change=reset_timepoint_on_change, - data_graphic_kwargs=data_graphic_kwargs, - gridplot_kwargs=gridplot_kwargs, - cmap=cmap, - component_colors=component_colors, - ) - - self._set_params_text_area(index=start_index) - - # set initial selected row - self.datagrid.select( - row1=start_index, - column1=0, - row2=start_index, - column2=len(df_show.columns), - clear_mode="all" - ) - - # callback when row changed - self.datagrid.observe(self._row_changed, names="selections") - - def _make_gridplot( - self, - start_index: int, - reset_timepoint_on_change: bool, - data_graphic_kwargs: dict, - gridplot_kwargs: dict, - cmap: str, - component_colors: str, - ): - - data_mapping = get_cnmf_data_mapping( - self._dataframe.iloc[start_index], - self.data_kwargs - ) - - self._gridplot_wrapper = GridPlotWrapper( - data=self._data, - data_mapping=data_mapping, - reset_timepoint_on_change=reset_timepoint_on_change, - data_graphic_kwargs=data_graphic_kwargs, - gridplot_kwargs=gridplot_kwargs, - cmap=cmap, - component_colors=component_colors - - ) - - self.gridplots = self._gridplot_wrapper.gridplots - - def show(self): - """Show the widget""" - - self.show_all_checkbox = Checkbox(value=True, description="Show all components") - - gridplots_widget = [gp.show() for gp in self.gridplots] - - if "Jupyter" in self.gridplots[0].canvas.__class__.__name__: - vbox_elements = gridplots_widget - else: - vbox_elements = list() - - widget = VBox( - [ - HBox([self.datagrid, self.params_text_area]), - self.show_all_checkbox, - HBox([self._gridplot_wrapper.component_slider, self._gridplot_wrapper.component_int_box]), - VBox(vbox_elements) - ] - ) - - self.show_all_checkbox.observe(self._toggle_show_all, "value") - - return widget - - def _toggle_show_all(self, change): - for line_collection in self._gridplot_wrapper.temporal_graphics: - line_collection[:].present = change["new"] - - def close(self): - """Close the widget""" - for gp in self.gridplots: - gp.close() - - def _get_selection_row(self) -> Union[int, None]: - r1 = self.datagrid.selections[0]["r1"] - r2 = self.datagrid.selections[0]["r2"] - - if r1 != r2: - warn("Only single row selection is currently allowed") - return - - # get corresponding dataframe index from currently visible dataframe - # since filtering etc. is possible - index = self.datagrid.get_visible_data().index[r1] - - return index - - def _row_changed(self, *args): - index = self._get_selection_row() - if index is None: - return - - if self.current_row == index: - return - - try: - data_mapping = get_cnmf_data_mapping( - self._dataframe.iloc[index], - self.data_kwargs - ) - self._gridplot_wrapper.change_data(data_mapping) - except Exception as e: - self.params_text_area.value = f"{type(e).__name__}\n" \ - f"{str(e)}\n\n" \ - f"See jupyter log for details" - raise e - - self._set_params_text_area(index) - - self.current_row = index - - def _set_params_text_area(self, index): - row = self._dataframe.iloc[index] - # try and get the param diffs - try: - param_diffs = self._dataframe.caiman.get_params_diffs( - algo=row["algo"], - item_name=row["item_name"] - ).iloc[index] - - diffs_dict = {"diffs": param_diffs} - diffs = f"{format_params(diffs_dict, 0)}\n\n" - except: - diffs = "" - - # diffs and full params - self.params_text_area.value = diffs + format_params(self._dataframe.iloc[index].params, 0) - - -@pd.api.extensions.register_dataframe_accessor("cnmf") -class CNMFDataFrameVizExtension: - def __init__(self, df): - self._dataframe = df - - def viz( - self, - data: List[str] = None, - start_index: int = 0, - reset_timepoint_on_change: bool = False, - data_graphic_kwargs: dict = None, - gridplot_kwargs: dict = None, - cmap: str = "gnuplot2", - component_colors: str = "random", - calcium_framerate: float = None, - other_data_loaders: Dict[str, callable] = None, - data_kwargs: dict = None, - data_grid_kwargs: dict = None, - ): - """ - Visualize motion correction output. - - Parameters - ---------- - data: list of str or list of list of str - default [["temporal"], ["input", "rcm", "rcb", "residuals"]] - list of data to plot, valid options are: - - +------------------+-----------------------------------------+ - | "input" | input movie | - +==================+=========================================+ - | "rcm" | reconstructed movie, A * C | - | "rcb" | reconstructed background, b * f | - | "residuals" | residuals, input - (A * C) - (b * f) | - | "corr" | correlation image, if computed | - | "pnr" | peak-noise-ratio image, if computed | - | "temporal" | temporal components overlaid | - | "temporal-stack" | temporal components stack | - | "heatmap" | temporal components heatmap | - | "rcm-mean" | rcm mean projection image | - | "rcm-min" | rcm min projection image | - | "rcm-max" | rcm max projection image | - | "rcm-std" | rcm standard deviation projection image | - | "rcb-mean" | rcb mean projection image | - | "rcb-min" | rcb min projection image | - | "rcb-max" | rcb max projection image | - | "rcb-std" | rcb standard deviation projection image | - | "mean" | mean projection image | - | "max" | max projection image | - | "std" | standard deviation projection image | - +------------------+-----------------------------------------+ - - - start_index: int, default 0 - start index item used to set the initial data in the ImageWidget - - reset_timepoint_on_change: bool, default False - reset the timepoint in the ImageWidget when changing items/rows - - data_grid_kwargs: dict, optional - kwargs passed to DataGrid() - - Returns - ------- - McorrVizContainer - widget that contains the DataGrid, params text box and ImageWidget - """ - container = CNMFVizContainer( - dataframe=self._dataframe, - data=data, - start_index=start_index, - reset_timepoint_on_change=reset_timepoint_on_change, - data_graphic_kwargs=data_graphic_kwargs, - gridplot_kwargs=gridplot_kwargs, - cmap=cmap, - component_colors=component_colors, - calcium_framerate=calcium_framerate, - other_data_loaders=other_data_loaders, - data_kwargs=data_kwargs, - data_grid_kwargs=data_grid_kwargs, - ) - - return container diff --git a/mesmerize_viz/_cnmf/__init__.py b/mesmerize_viz/_cnmf/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mesmerize_viz/_cnmf/_eval.py b/mesmerize_viz/_cnmf/_eval.py new file mode 100644 index 0000000..492229e --- /dev/null +++ b/mesmerize_viz/_cnmf/_eval.py @@ -0,0 +1,71 @@ +from typing import * +from collections import OrderedDict + +from ipywidgets import FloatSlider, FloatText, HBox, VBox, link, Layout, Label + + +class EvalWidgets: + def __init__(self): + # low thresholds + self._low_thresholds = OrderedDict( + rval_lowest=(-1.0, -1.0, 1.0), # (val, min, max) + SNR_lowest=(0.5, 0., 100), + cnn_lowest=(0.1, 0., 1.0), + ) + + # high thresholds + self.high_thresholds = OrderedDict( + rval_thr=(0.8, 0., 1.0), + min_SNR=(2.5, 0., 100), + min_cnn_thr=(0.9, 0., 1.0), + ) + + self._low_threshold_widget = list() + for k in self._low_thresholds: + kwargs = dict(value=self._low_thresholds[k][0], min=self._low_thresholds[k][1], max=self._low_thresholds[k][2], step=0.01, description=k) + slider = FloatSlider(**kwargs) + entry = FloatText(**kwargs, layout=Layout(width="150px")) + + link((slider, "value"), (entry, "value")) + + setattr(self, f"_{k}", entry) + + self._low_threshold_widget.append(HBox([slider, entry])) + + self._high_threshold_widgets = list() + for k in self.high_thresholds: + kwargs = dict(value=self.high_thresholds[k][0], min=self.high_thresholds[k][1], max=self.high_thresholds[k][2], step=0.01, description=k) + slider = FloatSlider(**kwargs) + entry = FloatText(**kwargs, layout=Layout(width="150px")) + + link((slider, "value"), (entry, "value")) + + setattr(self, f"_{k}", entry) + + self._high_threshold_widgets.append(HBox([slider, entry])) + + self.widget = VBox( + [ + Label("Low Thresholds"), + self._low_threshold_widget, + Label("High Thresholds"), + self._high_threshold_widgets + ] + ) + + def get_params(self): + """get the values from the GUI""" + + eval_params = dict() + for param in self._low_thresholds: + eval_params[param] = getattr(self, f"_{param}.value") + + for param in self._high_threshold_widgets: + eval_params[param] = getattr(self, f"_{param}.value") + + return eval_params + + def set_param(self, param: str, value: float): + w = getattr(self, f"_{param}") + + w.value = value diff --git a/mesmerize_viz/_cnmf/_extensions.py b/mesmerize_viz/_cnmf/_extensions.py new file mode 100644 index 0000000..fc837b7 --- /dev/null +++ b/mesmerize_viz/_cnmf/_extensions.py @@ -0,0 +1,90 @@ +from typing import * + +import pandas as pd + +from ._viz_container import CNMFVizContainer + + +@pd.api.extensions.register_dataframe_accessor("cnmf") +class CNMFDataFrameVizExtension: + def __init__(self, df): + self._dataframe = df + + def viz( + self, + data: List[str] = None, + start_index: int = 0, + reset_timepoint_on_change: bool = False, + data_graphic_kwargs: dict = None, + gridplot_kwargs: dict = None, + cmap: str = "gnuplot2", + component_colors: str = "random", + calcium_framerate: float = None, + other_data_loaders: Dict[str, callable] = None, + data_kwargs: dict = None, + data_grid_kwargs: dict = None, + ): + """ + Visualize motion correction output. + + Parameters + ---------- + data: list of str or list of list of str + default [["temporal"], ["input", "rcm", "rcb", "residuals"]] + list of data to plot, valid options are: + + +------------------+-----------------------------------------+ + | "input" | input movie | + +==================+=========================================+ + | "rcm" | reconstructed movie, A * C | + | "rcb" | reconstructed background, b * f | + | "residuals" | residuals, input - (A * C) - (b * f) | + | "corr" | correlation image, if computed | + | "pnr" | peak-noise-ratio image, if computed | + | "temporal" | temporal components overlaid | + | "temporal-stack" | temporal components stack | + | "heatmap" | temporal components heatmap | + | "rcm-mean" | rcm mean projection image | + | "rcm-min" | rcm min projection image | + | "rcm-max" | rcm max projection image | + | "rcm-std" | rcm standard deviation projection image | + | "rcb-mean" | rcb mean projection image | + | "rcb-min" | rcb min projection image | + | "rcb-max" | rcb max projection image | + | "rcb-std" | rcb standard deviation projection image | + | "mean" | mean projection image | + | "max" | max projection image | + | "std" | standard deviation projection image | + +------------------+-----------------------------------------+ + + + start_index: int, default 0 + start index item used to set the initial data in the ImageWidget + + reset_timepoint_on_change: bool, default False + reset the timepoint in the ImageWidget when changing items/rows + + data_grid_kwargs: dict, optional + kwargs passed to DataGrid() + + Returns + ------- + McorrVizContainer + widget that contains the DataGrid, params text box and ImageWidget + """ + container = CNMFVizContainer( + dataframe=self._dataframe, + data=data, + start_index=start_index, + reset_timepoint_on_change=reset_timepoint_on_change, + data_graphic_kwargs=data_graphic_kwargs, + gridplot_kwargs=gridplot_kwargs, + cmap=cmap, + component_colors=component_colors, + calcium_framerate=calcium_framerate, + other_data_loaders=other_data_loaders, + data_kwargs=data_kwargs, + data_grid_kwargs=data_grid_kwargs, + ) + + return container diff --git a/mesmerize_viz/_cnmf/_time_array.py b/mesmerize_viz/_cnmf/_time_array.py new file mode 100644 index 0000000..08143bb --- /dev/null +++ b/mesmerize_viz/_cnmf/_time_array.py @@ -0,0 +1,137 @@ +import math +from typing import Union, Tuple + +import numpy as np + +from mesmerize_core.arrays._base import LazyArray +from mesmerize_core.utils import quick_min_max + + +# TODO: maybe this can be used so that ImageWidget can be used for both behavior and calcium +# TODO: but then we need an option to set window_funcs separately for each subplot +class TimeArray(LazyArray): + """ + Wrapper for array-like that takes units of millisecond for slicing + """ + def __init__(self, array: Union[np.ndarray, LazyArray], timestamps = None, framerate = None): + """ + Arrays which can be sliced using timepoints in units of millisecond. + Supports slicing with start and stop timepoints, does not support slice steps. + + i.e. You can do this: time_array[30], time_array[30:], time_array[:50], time_array[30:50]. + You cannot do this: time_array[::10], time_array[0::10], time_array[0:50:10] + + Parameters + ---------- + array: array-like + data array, must have shape attribute and first dimension must be frame index + + timestamps: np.ndarray, 1 dimensional + timestamps in units of millisecond, you must provide either timestamps or framerate. + MUST be in order such that t_(n +1) > t_n for all n. + + framerate: float + framerate, in units of Hz (per second). You must provide either timestamps or framerate + """ + self._array = array + + if timestamps is None and framerate is None: + raise ValueError("Must provide timestamps or framerate") + + if timestamps is None: + # total duration in milliseconds = n_frames / framerate + n_frames = self.shape[0] + stop_time_ms = (n_frames / framerate) * 1000 + timestamps = np.linspace( + start=0, + stop=stop_time_ms, + num=n_frames, + endpoint=False + ) + + if timestamps.size != self._array.shape[0]: + raise ValueError("timestamps.size != array.shape[0]") + + self.timestamps = timestamps + + def _get_closest_index(self, timepoint: float): + """ + from: https://stackoverflow.com/a/26026189/4697983 + + This is very fast, 10 microseconds even for a + + Parameters + ---------- + timepoint: float + timepoint in milliseconds + + Returns + ------- + int + index for the closest timestamp, which also corresponds to the frame index of the data array + """ + value = timepoint + array = self.timestamps + + idx = np.searchsorted(array, value, side="left") + if idx > 0 and (idx == len(array) or math.fabs(value - array[idx - 1]) < math.fabs(value - array[idx])): + return idx - 1 + else: + return idx + + # override __getitem__ since it will not work with LazyArray base implementation since: + # 1. base implementation requires the slice indices to be less than shape[0] + # 2. base implementation does not consider slicing with float values + def __getitem__(self, indices: Union[slice, int, float]) -> np.ndarray: + if isinstance(indices, slice): + if indices.step is not None: + raise IndexError( + "TimeArray slicing does not support step, only start and stop. See docstring." + ) + + if indices.start is None: + start = 0 + else: + start = self._get_closest_index(indices.start) + + if indices.stop is None: + stop = self.n_frames + else: + stop = self._get_closest_index(indices.stop) + + s = slice(start, stop) + return self._array[s] + + # single index + index = self._get_closest_index(indices) + return self._array[index] + + def _compute_at_indices(self, indices: Union[int, slice]) -> np.ndarray: + """not implemented here""" + pass + + @property + def n_frames(self) -> int: + return self.shape[0] + + @property + def shape(self) -> Tuple[int, int, int]: + return self._array.shape + + @property + def dtype(self) -> str: + return str(self._array.dtype) + + @property + def min(self) -> float: + if isinstance(self._array, LazyArray): + return self._array.min + else: + return quick_min_max(self._array)[0] + + @property + def max(self) -> float: + if isinstance(self._array, LazyArray): + return self._array.max + else: + return quick_min_max(self._array)[1] diff --git a/mesmerize_viz/_cnmf/_viz_container.py b/mesmerize_viz/_cnmf/_viz_container.py new file mode 100644 index 0000000..d6390b6 --- /dev/null +++ b/mesmerize_viz/_cnmf/_viz_container.py @@ -0,0 +1,287 @@ +import itertools +from _warnings import warn +from typing import * + +import pandas as pd +from ipydatagrid import DataGrid +from ipywidgets import Textarea, Layout, Checkbox, VBox, HBox + +from ._wrapper import VALID_DATA_OPTIONS, get_cnmf_data_mapping, GridPlotWrapper +from .._utils import format_params + + +class CNMFVizContainer: + """Widget that contains the DataGrid, params text box fastplotlib GridPlot, etc""" + + def __init__( + self, + dataframe: pd.DataFrame, + data: List[str] = None, + start_index: int = 0, + reset_timepoint_on_change: bool = False, + data_graphic_kwargs: dict = None, + gridplot_kwargs: dict = None, + cmap: str = "gnuplot2", + component_colors: str = "random", + calcium_framerate: float = None, + other_data_loaders: Dict[str, callable] = None, + data_kwargs: dict = None, + data_grid_kwargs: dict = None, + ): + """ + Visualize CNMF output and other data columns such as behavior video (optional) + + Parameters + ---------- + dataframe: pd.DataFrame + + data: list of str + data options, such as "input", "temporal", "contours", etc. + + start_index + + reset_timepoint_on_change + + calcium_framerate + + other_data_loaders: Dict[str, callable] + if loading non-calcium related data arrays, provide dict of callables for opening them. + Example, if you provide ``data = ["contours", "temporal", "behavior"]``, and the "behavior" + column contains videos, you could provide `other_data_loads = {"behavior": LazyVideo} + + data_kwargs: dict + kwargs passed to corresponding extension function to load data. + example: ``{"temporal": {"component_ixs": "good"}}`` + + gridplot_kwargs: List[dict] + kwargs passed to GridPlot + + data_grid_kwargs + """ + + if data is None: + data = [["temporal"], ["input", "rcm", "rcb", "residuals"]] + + if other_data_loaders is None: + other_data_loaders = dict() + + # simple list of str, single gridplot + if all(isinstance(option, str) for option in data): + data = [data] + + if not all(isinstance(option, list) for option in data): + raise TypeError( + "Must pass list of str or nested list of str" + ) + + # make sure data options are valid + for d in list(itertools.chain(*data)): + if (d not in VALID_DATA_OPTIONS) and (d not in dataframe.columns): + raise ValueError( + f"`data` options are: {VALID_DATA_OPTIONS} or a DataFrame column name: {dataframe.columns}\n" + f"You have passed: {d}" + ) + + if d in dataframe.columns: + if d not in other_data_loaders.keys(): + raise ValueError( + f"You have provided the non-CNMF related data option: {d}.\n" + f"If you provide a non-cnmf related data option you must also provide a " + f"data loader callable for it to `other_data_loaders`" + ) + + self._other_data_loaders = other_data_loaders + + if data_grid_kwargs is None: + data_grid_kwargs = dict() + + self._dataframe = dataframe + + default_widths = { + "algo": 50, + 'item_name': 200, + 'input_movie_path': 120, + 'algo_duration': 80, + 'comments': 120, + 'uuid': 60 + } + + columns = dataframe.columns + # these add clutter + hide_columns = [ + "params", + "outputs", + "added_time", + "ran_time", + + ] + + df_show = self._dataframe[[c for c in columns if c not in hide_columns]] + + self.datagrid = DataGrid( + df_show, # show only a subset + selection_mode="cell", + layout={"height": "250px", "width": "750px"}, + base_row_size=24, + index_name="index", + column_widths=default_widths, + **data_grid_kwargs + ) + + self.params_text_area = Textarea() + self.params_text_area.layout = Layout( + height="250px", + max_height="250px", + width="360px", + max_width="500px" + ) + + # data options is private since this can't be changed once an image widget has been made + self._data = data + + if data_kwargs is None: + data_kwargs = dict() + + self.data_kwargs = data_kwargs + + self.current_row: int = start_index + + self._make_gridplot( + start_index=start_index, + reset_timepoint_on_change=reset_timepoint_on_change, + data_graphic_kwargs=data_graphic_kwargs, + gridplot_kwargs=gridplot_kwargs, + cmap=cmap, + component_colors=component_colors, + ) + + self._set_params_text_area(index=start_index) + + # set initial selected row + self.datagrid.select( + row1=start_index, + column1=0, + row2=start_index, + column2=len(df_show.columns), + clear_mode="all" + ) + + # callback when row changed + self.datagrid.observe(self._row_changed, names="selections") + + def _make_gridplot( + self, + start_index: int, + reset_timepoint_on_change: bool, + data_graphic_kwargs: dict, + gridplot_kwargs: dict, + cmap: str, + component_colors: str, + ): + + data_mapping = get_cnmf_data_mapping( + self._dataframe.iloc[start_index], + self.data_kwargs + ) + + self._gridplot_wrapper = GridPlotWrapper( + data=self._data, + data_mapping=data_mapping, + reset_timepoint_on_change=reset_timepoint_on_change, + data_graphic_kwargs=data_graphic_kwargs, + gridplot_kwargs=gridplot_kwargs, + cmap=cmap, + component_colors=component_colors + + ) + + self.gridplots = self._gridplot_wrapper.gridplots + + def show(self): + """Show the widget""" + + self.show_all_checkbox = Checkbox(value=True, description="Show all components") + + gridplots_widget = [gp.show() for gp in self.gridplots] + + if "Jupyter" in self.gridplots[0].canvas.__class__.__name__: + vbox_elements = gridplots_widget + else: + vbox_elements = list() + + widget = VBox( + [ + HBox([self.datagrid, self.params_text_area]), + self.show_all_checkbox, + HBox([self._gridplot_wrapper.component_slider, self._gridplot_wrapper.component_int_box]), + VBox(vbox_elements) + ] + ) + + self.show_all_checkbox.observe(self._toggle_show_all, "value") + + return widget + + def _toggle_show_all(self, change): + for line_collection in self._gridplot_wrapper.temporal_graphics: + line_collection[:].present = change["new"] + + def close(self): + """Close the widget""" + for gp in self.gridplots: + gp.close() + + def _get_selection_row(self) -> Union[int, None]: + r1 = self.datagrid.selections[0]["r1"] + r2 = self.datagrid.selections[0]["r2"] + + if r1 != r2: + warn("Only single row selection is currently allowed") + return + + # get corresponding dataframe index from currently visible dataframe + # since filtering etc. is possible + index = self.datagrid.get_visible_data().index[r1] + + return index + + def _row_changed(self, *args): + index = self._get_selection_row() + if index is None: + return + + if self.current_row == index: + return + + try: + data_mapping = get_cnmf_data_mapping( + self._dataframe.iloc[index], + self.data_kwargs + ) + self._gridplot_wrapper.change_data(data_mapping) + except Exception as e: + self.params_text_area.value = f"{type(e).__name__}\n" \ + f"{str(e)}\n\n" \ + f"See jupyter log for details" + raise e + + self._set_params_text_area(index) + + self.current_row = index + + def _set_params_text_area(self, index): + row = self._dataframe.iloc[index] + # try and get the param diffs + try: + param_diffs = self._dataframe.caiman.get_params_diffs( + algo=row["algo"], + item_name=row["item_name"] + ).iloc[index] + + diffs_dict = {"diffs": param_diffs} + diffs = f"{format_params(diffs_dict, 0)}\n\n" + except: + diffs = "" + + # diffs and full params + self.params_text_area.value = diffs + format_params(self._dataframe.iloc[index].params, 0) diff --git a/mesmerize_viz/_cnmf/_wrapper.py b/mesmerize_viz/_cnmf/_wrapper.py new file mode 100644 index 0000000..6cf0582 --- /dev/null +++ b/mesmerize_viz/_cnmf/_wrapper.py @@ -0,0 +1,478 @@ +from functools import partial +from itertools import product +from typing import Union, List, Dict + +import numpy as np +import pandas as pd + +from ipywidgets import IntSlider, BoundedIntText, jslink + +from fastplotlib import GridPlot, graphics +from fastplotlib.graphics.selectors import LinearSelector, Synchronizer +from fastplotlib.utils import calculate_gridshape + + +# basic data options +VALID_DATA_OPTIONS = [ + "contours", + "empty" +] + + +IMAGE_OPTIONS = [ + "input", + "rcm", + "rcb", + "residuals", + "corr", + "pnr", +] + +rcm_rcb_proj_options = list() +# RCM and RCB projections +for option in ["rcm", "rcb"]: + for proj in ["mean", "min", "max", "std"]: + rcm_rcb_proj_options.append(f"{option}-{proj}") + +IMAGE_OPTIONS += rcm_rcb_proj_options + +TEMPORAL_OPTIONS = [ + "temporal", + "temporal-stack", + "heatmap", +] + +projs = [ + "mean", + "max", + "std", +] + +IMAGE_OPTIONS += projs + +VALID_DATA_OPTIONS += IMAGE_OPTIONS +VALID_DATA_OPTIONS += TEMPORAL_OPTIONS + + +class ExtensionCallWrapper: + def __init__(self, extension_func: callable, kwargs: dict = None, attr: str = None): + """ + Basically like ``functools.partial`` but supports kwargs. + + Parameters + ---------- + extension_func: callable + extension function reference + + kwargs: dict + kwargs to pass to the extension function when it is called + + attr: str, optional + return an attribute of the callable's output instead of the return value of the callable. + Example: if using rcm, can set ``attr="max_image"`` to return the max proj of the RCM. + """ + + if kwargs is None: + self.kwargs = dict() + else: + self.kwargs = kwargs + + self.func = extension_func + self.attr = attr + + def __call__(self, *args, **kwargs): + rval = self.func(**self.kwargs) + + if self.attr is not None: + return getattr(rval, self.attr) + + return rval + + +def get_cnmf_data_mapping(series: pd.Series, data_kwargs: dict = None, other_data_loaders: dict = None) -> dict: + """ + Returns dict that maps data option str to a callable that can return the corresponding data array. + + For example, ``{"input": series.get_input_movie}`` maps "input" -> series.get_input_movie + + Parameters + ---------- + series: pd.Series + row/item to get mapping from + + data_kwargs: dict, optional + optional kwargs for each of the extension functions + + other_data_loaders: dict + {"data_option": callable}, example {"behavior": LazyVideo} + + Returns + ------- + dict + {data label: callable} + """ + if data_kwargs is None: + data_kwargs = dict() + + if other_data_loaders is None: + other_data_loaders = dict() + + default_extension_kwargs = {k: dict() for k in VALID_DATA_OPTIONS + list(other_data_loaders.keys())} + + default_extension_kwargs["contours"] = {"swap_dim": False} + + ext_kwargs = { + **default_extension_kwargs, + **data_kwargs + } + + projections = {k: partial(series.caiman.get_projection, k) for k in projs} + + other_data_loaders_mapping = dict() + + # make ExtensionCallWrapers for other data loaders + for option in list(other_data_loaders.keys()): + other_data_loaders_mapping[option] = ExtensionCallWrapper(other_data_loaders[option], ext_kwargs[option]) + + rcm_rcb_projs = dict() + for proj in ["mean", "min", "max", "std"]: + rcm_rcb_projs[f"rcm-{proj}"] = ExtensionCallWrapper( + series.cnmf.get_rcm, + ext_kwargs["rcm"], + attr=f"{proj}_image" + ) + + temporal_mappings = { + k: ExtensionCallWrapper(series.cnmf.get_temporal, ext_kwargs[k]) for k in TEMPORAL_OPTIONS + } + + m = { + "input": ExtensionCallWrapper(series.caiman.get_input_movie, ext_kwargs["input"]), + "rcm": ExtensionCallWrapper(series.cnmf.get_rcm, ext_kwargs["rcm"]), + "rcb": ExtensionCallWrapper(series.cnmf.get_rcb, ext_kwargs["rcb"]), + "residuals": ExtensionCallWrapper(series.cnmf.get_residuals, ext_kwargs["residuals"]), + "corr": ExtensionCallWrapper(series.caiman.get_corr_image, ext_kwargs["corr"]), + "contours": ExtensionCallWrapper(series.cnmf.get_contours, ext_kwargs["contours"]), + "empty": None, + **temporal_mappings, + **projections, + **rcm_rcb_projs, + **other_data_loaders_mapping + } + + return m + + +class GridPlotWrapper: + """Wraps GridPlot in a way that allows updating the data""" + + def __init__( + self, + data: Union[List[str], List[List[str]]], + data_mapping: Dict[str, ExtensionCallWrapper], + reset_timepoint_on_change: bool = False, + data_graphic_kwargs: dict = None, + # slider_ipywidget: ipywidgets.IntSlider = None, + gridplot_kwargs: dict = None, + cmap: str = "gnuplot2", + component_colors: str = "random" + ): + """ + Visualize motion correction output. + + Parameters + ---------- + data: list of str or list of list of str + list of data to plot, examples: ["input", "temporal-stack"], [["temporal"], ["rcm", "rcb"]] + + data_mapping: dict + maps {"data_option": callable} + + reset_timepoint_on_change: bool, default False + reset the timepoint in the ImageWidget when changing items/rows + + data_graphic_kwargs: dict + passed add_ for corresponding graphic + + slider_ipywidget: ipywidgets.IntSlider + time slider from ImageWidget + + gridplot_kwargs: dict, optional + kwargs passed to GridPlot + + """ + + self._data = data + + if data_graphic_kwargs is None: + data_graphic_kwargs = dict() + + self.data_graphic_kwargs = data_graphic_kwargs + + if gridplot_kwargs is None: + gridplot_kwargs = dict() + + self._cmap = cmap + + self.component_colors = component_colors + + # self._slider_ipywidget = slider_ipywidget + + self.reset_timepoint_on_change = reset_timepoint_on_change + + self.gridplots: List[GridPlot] = list() + + self.component_slider = IntSlider(min=0, max=1, value=0, step=1, description="component index:") + self.component_int_box = BoundedIntText(min=0, max=1, value=0, step=1) + for trait in ["value", "max"]: + jslink((self.component_slider, trait), (self.component_int_box, trait)) + + self.component_int_box.observe(self.set_component_index, "value") + + # gridplot for each sublist + for sub_data in self._data: + _gridplot_kwargs = {"shape": calculate_gridshape(len(sub_data))} + _gridplot_kwargs.update(gridplot_kwargs) + self.gridplots.append(GridPlot(**_gridplot_kwargs)) + + self.temporal_graphics: List[graphics.LineCollection] = list() + self.temporal_stack_graphics: List[graphics.LineStack] = list() + self.heatmap_graphics: List[graphics.HeatmapGraphic] = list() + self.image_graphics: List[graphics.ImageGraphic] = list() + self.contour_graphics: List[graphics.LineCollection] = list() + + self._managed_graphics: List[list] = [ + self.temporal_graphics, + self.temporal_stack_graphics, + self.image_graphics, + self.contour_graphics + ] + + # to store only image data in a 1:1 mapping to the graphics list + self.image_graphic_arrays: List[np.ndarray] = list() + + self.linear_selectors: List[LinearSelector] = list() + + self._current_frame_index: int = 0 + + self.change_data(data_mapping) + + def set_component_index(self, change): + index = change["new"] + + for g in self.contour_graphics: + g._set_feature(feature="colors", new_data="w", indices=index) + + def _parse_data(self, data_options, data_mapping) -> List[List[np.ndarray]]: + """ + Returns nested list of array-like + """ + data_arrays = list() + + for d in data_options: + if isinstance(d, list): + data_arrays.append(self._parse_data(d, data_mapping)) + + elif d == "empty": + data_arrays.append(None) + + else: + func = data_mapping[d] + a = func() + data_arrays.append(a) + + return data_arrays + + @property + def cmap(self) -> str: + return self._cmap + + @cmap.setter + def cmap(self, cmap: str): + for g in self.image_graphics: + g.cmap = cmap + + # @property + # def component_colors(self) -> Any: + # pass + # + # @component_colors.setter + # def component_colors(self, colors: Any): + # for collection in self.contour_graphics: + # for g in collection.graphics: + # + + def change_data(self, data_mapping: Dict[str, callable]): + for l in self._managed_graphics: + l.clear() + + self.image_graphic_arrays.clear() + + # clear existing subplots + for gp in self.gridplots: + gp.clear() + + # new data arrays + data_arrays = self._parse_data(data_options=self._data, data_mapping=data_mapping) + + # rval is (contours, centeres of masses) + contours = data_mapping["contours"]()[0] + + if self.component_colors == "random": + n_components = len(contours) + component_colors = np.random.rand(n_components, 4).astype(np.float32) + component_colors[:, -1] = 1 + else: + component_colors = self.component_colors + + self.component_slider.value = 0 + self.component_slider.max = len(contours) + + # change data for all gridplots + for sub_data, sub_data_arrays, gridplot in zip(self._data, data_arrays, self.gridplots): + self._change_data_gridplot(sub_data, sub_data_arrays, gridplot, contours, component_colors) + + # connect events + self._connect_events() + + def _change_data_gridplot( + self, + data: List[str], + data_arrays: List[np.ndarray], + gridplot: GridPlot, + contours, + component_colors + ): + + if self.reset_timepoint_on_change: + self._current_frame_index = 0 + + for data_option, data_array, subplot in zip(data, data_arrays, gridplot): + if data_option in self.data_graphic_kwargs.keys(): + graphic_kwargs = self.data_graphic_kwargs[data_option] + else: + graphic_kwargs = dict() + # skip + if data_option == "empty": + continue + + elif data_option == "temporal": + current_graphic = subplot.add_line_collection( + data_array, + colors=component_colors, + name="components", + **graphic_kwargs + ) + current_graphic[:].present.add_event_handler(subplot.auto_scale) + self.temporal_graphics.append(current_graphic) + + # otherwise the plot has nothing in it which causes issues + subplot.add_line(np.random.rand(data_array.shape[1]), colors=(0, 0, 0, 0), name="pseudo-line") + + # scale according to temporal dims + subplot.camera.maintain_aspect = False + + elif data_option == "temporal-stack": + current_graphic = subplot.add_line_stack( + data_array, + colors=component_colors, + name="components", + **graphic_kwargs + ) + self.temporal_stack_graphics.append(current_graphic) + + # scale according to temporal dims + subplot.camera.maintain_aspect = False + + elif data_option == "heatmap": + current_graphic = subplot.add_heatmap( + data_array, + name="components", + **graphic_kwargs + ) + self.heatmap_graphics.append(current_graphic) + + # scale according to temporal dims + subplot.camera.maintain_aspect = False + + else: + img_graphic = subplot.add_image( + data_array[self._current_frame_index], + cmap=self.cmap, + name="image", + **graphic_kwargs + ) + + self.image_graphics.append(img_graphic) + self.image_graphic_arrays.append(data_array) + + contour_graphic = subplot.add_line_collection( + contours, + colors=component_colors, + name="contours" + ) + + self.contour_graphics.append(contour_graphic) + + subplot.name = data_option + + if data_option in TEMPORAL_OPTIONS: + self.linear_selectors.append(current_graphic.add_linear_selector()) + subplot.camera.maintain_aspect = False + + if len(self.linear_selectors) > 0: + self._synchronizer = Synchronizer(*self.linear_selectors) + + for ls in self.linear_selectors: + ls.selection.add_event_handler(self.set_frame_index) + + def _euclidean(self, source, target, event, new_data): + """maps click events to contour""" + # calculate coms of line collection + indices = np.array(event.pick_info["index"]) + + coms = list() + + for contour in target.graphics: + coors = contour.data()[~np.isnan(contour.data()).any(axis=1)] + com = coors.mean(axis=0) + coms.append(com) + + # euclidean distance to find closest index of com + indices = np.append(indices, [0]) + + ix = int(np.linalg.norm((coms - indices), axis=1).argsort()[0]) + + target._set_feature(feature="colors", new_data=new_data, indices=ix) + + self.component_int_box.value = ix + + return None + + def _connect_events(self): + for image_graphic, contour_graphic in zip(self.image_graphics, self.contour_graphics): + image_graphic.link( + "click", + target=contour_graphic, + feature="colors", + new_data="w", + callback=self._euclidean + ) + + contour_graphic.link("colors", target=contour_graphic, feature="thickness", new_data=5) + + for temporal_graphic in self.temporal_graphics: + contour_graphic.link("colors", target=temporal_graphic, feature="present", new_data=True) + + for cg, tsg in product(self.contour_graphics, self.temporal_stack_graphics): + cg.link("colors", target=contour_graphic, feature="colors", new_data="w", bidirectional=True) + + def set_frame_index(self, ev): + # 0 because this will return the same number repeated * n_components + index = ev.pick_info["selected_index"][0] + for image_graphic, full_array in zip(self.image_graphics, self.image_graphic_arrays): + # txy data + if full_array.ndim > 2: + image_graphic.data = full_array[index] + + self._current_frame_index = index From 7ba3c0b737b3ee3743de70c05b5c9b7ec70aa7ac Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Mon, 18 Sep 2023 05:49:42 -0400 Subject: [PATCH 04/36] progress, heatmap works with selectors --- mesmerize_viz/_cnmf/__init__.py | 1 + mesmerize_viz/_cnmf/_viz_container.py | 14 +-- mesmerize_viz/_cnmf/_wrapper.py | 130 ++++++++++++++++++++------ 3 files changed, 107 insertions(+), 38 deletions(-) diff --git a/mesmerize_viz/_cnmf/__init__.py b/mesmerize_viz/_cnmf/__init__.py index e69de29..a93e4fd 100644 --- a/mesmerize_viz/_cnmf/__init__.py +++ b/mesmerize_viz/_cnmf/__init__.py @@ -0,0 +1 @@ +from ._extensions import CNMFDataFrameVizExtension \ No newline at end of file diff --git a/mesmerize_viz/_cnmf/_viz_container.py b/mesmerize_viz/_cnmf/_viz_container.py index d6390b6..b838988 100644 --- a/mesmerize_viz/_cnmf/_viz_container.py +++ b/mesmerize_viz/_cnmf/_viz_container.py @@ -4,7 +4,7 @@ import pandas as pd from ipydatagrid import DataGrid -from ipywidgets import Textarea, Layout, Checkbox, VBox, HBox +from ipywidgets import Textarea, Layout, VBox, HBox from ._wrapper import VALID_DATA_OPTIONS, get_cnmf_data_mapping, GridPlotWrapper from .._utils import format_params @@ -200,8 +200,6 @@ def _make_gridplot( def show(self): """Show the widget""" - self.show_all_checkbox = Checkbox(value=True, description="Show all components") - gridplots_widget = [gp.show() for gp in self.gridplots] if "Jupyter" in self.gridplots[0].canvas.__class__.__name__: @@ -211,21 +209,13 @@ def show(self): widget = VBox( [ - HBox([self.datagrid, self.params_text_area]), - self.show_all_checkbox, - HBox([self._gridplot_wrapper.component_slider, self._gridplot_wrapper.component_int_box]), + HBox([self.datagrid, self.params_text_area]), HBox([self._gridplot_wrapper.component_slider, self._gridplot_wrapper.component_int_box]), VBox(vbox_elements) ] ) - self.show_all_checkbox.observe(self._toggle_show_all, "value") - return widget - def _toggle_show_all(self, change): - for line_collection in self._gridplot_wrapper.temporal_graphics: - line_collection[:].present = change["new"] - def close(self): """Close the widget""" for gp in self.gridplots: diff --git a/mesmerize_viz/_cnmf/_wrapper.py b/mesmerize_viz/_cnmf/_wrapper.py index 6cf0582..1e409a6 100644 --- a/mesmerize_viz/_cnmf/_wrapper.py +++ b/mesmerize_viz/_cnmf/_wrapper.py @@ -8,7 +8,7 @@ from ipywidgets import IntSlider, BoundedIntText, jslink from fastplotlib import GridPlot, graphics -from fastplotlib.graphics.selectors import LinearSelector, Synchronizer +from fastplotlib.graphics.selectors import LinearSelector, Synchronizer, LinearRegionSelector from fastplotlib.utils import calculate_gridshape @@ -227,7 +227,9 @@ def __init__( for trait in ["value", "max"]: jslink((self.component_slider, trait), (self.component_int_box, trait)) - self.component_int_box.observe(self.set_component_index, "value") + self.component_int_box.observe( + lambda change: self.set_component_index(change["new"]), "value" + ) # gridplot for each sublist for sub_data in self._data: @@ -235,12 +237,14 @@ def __init__( _gridplot_kwargs.update(gridplot_kwargs) self.gridplots.append(GridPlot(**_gridplot_kwargs)) - self.temporal_graphics: List[graphics.LineCollection] = list() + self.temporal_graphics: List[graphics.LineGraphic] = list() self.temporal_stack_graphics: List[graphics.LineStack] = list() self.heatmap_graphics: List[graphics.HeatmapGraphic] = list() self.image_graphics: List[graphics.ImageGraphic] = list() self.contour_graphics: List[graphics.LineCollection] = list() + self.heatmap_selectors: List[LinearSelector] = list() + self._managed_graphics: List[list] = [ self.temporal_graphics, self.temporal_stack_graphics, @@ -255,14 +259,30 @@ def __init__( self._current_frame_index: int = 0 - self.change_data(data_mapping) + self._current_temporal_components: np.ndarray = None - def set_component_index(self, change): - index = change["new"] + self.change_data(data_mapping) + def set_component_index(self, index: int): + # TODO: more elegant way than skip_heatmap for g in self.contour_graphics: g._set_feature(feature="colors", new_data="w", indices=index) + for g in self.temporal_graphics: + g.data = self._current_temporal_components[index] + + for s in self.heatmap_selectors: + # TODO: Very hacky for now, ignores if the slider is currently being moved, prevents weird slider movement + if s._move_info is None: + s.selection = index + + self.component_int_box.value = index + + def _heatmap_set_component_index(self, ev): + index = ev.pick_info["selected_index"] + + self.set_component_index(index) + def _parse_data(self, data_options, data_mapping) -> List[List[np.ndarray]]: """ Returns nested list of array-like @@ -303,11 +323,30 @@ def cmap(self, cmap: str): # def change_data(self, data_mapping: Dict[str, callable]): + """ + Changes the data shown in the gridplot. + + Clears all the gridplots, makes and adds new graphics + + Parameters + ---------- + data_mapping + + Returns + ------- + + """ for l in self._managed_graphics: l.clear() + self.heatmap_selectors.clear() + self.linear_selectors.clear() + self.image_graphic_arrays.clear() + # clear out old array that stores temporal components + self._current_temporal_components = None + # clear existing subplots for gp in self.gridplots: gp.clear() @@ -326,7 +365,7 @@ def change_data(self, data_mapping: Dict[str, callable]): component_colors = self.component_colors self.component_slider.value = 0 - self.component_slider.max = len(contours) + self.component_slider.max = len(contours) - 1 # change data for all gridplots for sub_data, sub_data_arrays, gridplot in zip(self._data, data_arrays, self.gridplots): @@ -335,6 +374,16 @@ def change_data(self, data_mapping: Dict[str, callable]): # connect events self._connect_events() + # sync sliders if multiple are present + if len(self.linear_selectors) > 0: + self._synchronizer = Synchronizer(*self.linear_selectors, key_bind=None) + + for ls in self.linear_selectors: + ls.selection.add_event_handler(self.set_frame_index) + + for hs in self.heatmap_selectors: + hs.selection.add_event_handler(self._heatmap_set_component_index) + def _change_data_gridplot( self, data: List[str], @@ -343,6 +392,23 @@ def _change_data_gridplot( contours, component_colors ): + """ + Changes data in a single gridplot. + + Create the corresponding graphics. + + Parameters + ---------- + data + data_arrays + gridplot + contours + component_colors + + Returns + ------- + + """ if self.reset_timepoint_on_change: self._current_frame_index = 0 @@ -357,17 +423,22 @@ def _change_data_gridplot( continue elif data_option == "temporal": - current_graphic = subplot.add_line_collection( - data_array, - colors=component_colors, - name="components", + # Only few one line at a time + current_graphic = subplot.add_line( + data_array[0], + colors="w", + name="line", **graphic_kwargs ) - current_graphic[:].present.add_event_handler(subplot.auto_scale) + + current_graphic.data.add_event_handler(subplot.auto_scale) self.temporal_graphics.append(current_graphic) + if self._current_temporal_components is None: + self._current_temporal_components = data_array + # otherwise the plot has nothing in it which causes issues - subplot.add_line(np.random.rand(data_array.shape[1]), colors=(0, 0, 0, 0), name="pseudo-line") + # subplot.add_line(np.random.rand(data_array.shape[1]), colors=(0, 0, 0, 0), name="pseudo-line") # scale according to temporal dims subplot.camera.maintain_aspect = False @@ -376,7 +447,7 @@ def _change_data_gridplot( current_graphic = subplot.add_line_stack( data_array, colors=component_colors, - name="components", + name="lines", **graphic_kwargs ) self.temporal_stack_graphics.append(current_graphic) @@ -387,7 +458,7 @@ def _change_data_gridplot( elif data_option == "heatmap": current_graphic = subplot.add_heatmap( data_array, - name="components", + name="heatmap", **graphic_kwargs ) self.heatmap_graphics.append(current_graphic) @@ -395,9 +466,22 @@ def _change_data_gridplot( # scale according to temporal dims subplot.camera.maintain_aspect = False + selector = current_graphic.add_linear_selector( + axis="y", + color=(1, 1, 1, 0.5), + thickness=5, + ) + + self.heatmap_selectors.append(selector) + else: + # else it is an image + if data_array.ndim == 3: + frame = data_array[self._current_frame_index] + else: + frame = data_array img_graphic = subplot.add_image( - data_array[self._current_frame_index], + frame, cmap=self.cmap, name="image", **graphic_kwargs @@ -420,12 +504,6 @@ def _change_data_gridplot( self.linear_selectors.append(current_graphic.add_linear_selector()) subplot.camera.maintain_aspect = False - if len(self.linear_selectors) > 0: - self._synchronizer = Synchronizer(*self.linear_selectors) - - for ls in self.linear_selectors: - ls.selection.add_event_handler(self.set_frame_index) - def _euclidean(self, source, target, event, new_data): """maps click events to contour""" # calculate coms of line collection @@ -443,7 +521,7 @@ def _euclidean(self, source, target, event, new_data): ix = int(np.linalg.norm((coms - indices), axis=1).argsort()[0]) - target._set_feature(feature="colors", new_data=new_data, indices=ix) + self.set_component_index(ix) self.component_int_box.value = ix @@ -461,15 +539,15 @@ def _connect_events(self): contour_graphic.link("colors", target=contour_graphic, feature="thickness", new_data=5) - for temporal_graphic in self.temporal_graphics: - contour_graphic.link("colors", target=temporal_graphic, feature="present", new_data=True) + # for temporal_graphic in self.temporal_graphics: + # contour_graphic.link("colors", target=temporal_graphic, feature="present", new_data=True) for cg, tsg in product(self.contour_graphics, self.temporal_stack_graphics): cg.link("colors", target=contour_graphic, feature="colors", new_data="w", bidirectional=True) def set_frame_index(self, ev): # 0 because this will return the same number repeated * n_components - index = ev.pick_info["selected_index"][0] + index = ev.pick_info["selected_index"] for image_graphic, full_array in zip(self.image_graphics, self.image_graphic_arrays): # txy data if full_array.ndim > 2: From 8e8bdf011dac0290ebe2323318eea5ff241b8a6c Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 20 Sep 2023 20:24:16 -0400 Subject: [PATCH 05/36] mean window slider for mcorr widget --- mesmerize_viz/_mcorr.py | 29 +++++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/mesmerize_viz/_mcorr.py b/mesmerize_viz/_mcorr.py index 26e5fc3..71445cb 100644 --- a/mesmerize_viz/_mcorr.py +++ b/mesmerize_viz/_mcorr.py @@ -10,7 +10,7 @@ from fastplotlib import ImageWidget from ipydatagrid import DataGrid -from ipywidgets import Textarea, VBox, HBox, Layout +from ipywidgets import Textarea, VBox, HBox, Layout, IntSlider, Checkbox from ._utils import validate_data_options, ZeroArray, format_params from ._common import ImageWidgetWrapper @@ -184,6 +184,14 @@ def __init__( # callback when row changed self.datagrid.observe(self._row_changed, names="selections") + def _set_mean_window_size(self, change): + self.image_widget.window_funcs = {"t": (np.mean, change["new"])} + self.image_widget.current_index = self.image_widget.current_index + + def _set_mean_diff(self, change): + # TODO: will do later + pass + def _make_image_widget(self, index): self._image_widget_wrapper = ImageWidgetWrapper( data=self._data, @@ -196,6 +204,22 @@ def _make_image_widget(self, index): self.image_widget = self._image_widget_wrapper.image_widget + # mean window slider + self._slider_mean_window = IntSlider( + min=1, + step=2, + max=99, + value=33, + description="mean wind", + description_tooltip="set a mean rolling window" + ) + self._slider_mean_window.observe(self._set_mean_window_size, "value") + + # TODO: mean diff checkbox + #self._checkbox_mean_diff + + self.image_widget.window_funcs = {"t": (np.mean, self._slider_mean_window.value)} + def _get_selection_row(self) -> Union[int, None]: r1 = self.datagrid.selections[0]["r1"] r2 = self.datagrid.selections[0]["r2"] @@ -262,7 +286,8 @@ def show(self): return VBox([ HBox([self.datagrid, self.params_text_area]), - self.image_widget.show() + self.image_widget.show(), + self._slider_mean_window ]) From 9a702d470f354d58066d1d2a994a473cb9965c0c Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Thu, 12 Oct 2023 20:26:47 -0400 Subject: [PATCH 06/36] sidecar for cnmf viz, update cnmf viz to latest fpl --- mesmerize_viz/_cnmf/_viz_container.py | 15 ++++++++++++--- mesmerize_viz/_cnmf/_wrapper.py | 2 +- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/mesmerize_viz/_cnmf/_viz_container.py b/mesmerize_viz/_cnmf/_viz_container.py index b838988..1621520 100644 --- a/mesmerize_viz/_cnmf/_viz_container.py +++ b/mesmerize_viz/_cnmf/_viz_container.py @@ -5,6 +5,8 @@ import pandas as pd from ipydatagrid import DataGrid from ipywidgets import Textarea, Layout, VBox, HBox +from IPython.display import display +from sidecar import Sidecar from ._wrapper import VALID_DATA_OPTIONS, get_cnmf_data_mapping, GridPlotWrapper from .._utils import format_params @@ -169,6 +171,8 @@ def __init__( # callback when row changed self.datagrid.observe(self._row_changed, names="selections") + self.sidecar = None + def _make_gridplot( self, start_index: int, @@ -197,16 +201,20 @@ def _make_gridplot( self.gridplots = self._gridplot_wrapper.gridplots - def show(self): + def show(self, sidecar: bool = True): """Show the widget""" - gridplots_widget = [gp.show() for gp in self.gridplots] + # create gridplots and start render loop + gridplots_widget = [gp.show(sidecar=False) for gp in self.gridplots] if "Jupyter" in self.gridplots[0].canvas.__class__.__name__: vbox_elements = gridplots_widget else: vbox_elements = list() + if self.sidecar is None: + self.sidecar = Sidecar() + widget = VBox( [ HBox([self.datagrid, self.params_text_area]), HBox([self._gridplot_wrapper.component_slider, self._gridplot_wrapper.component_int_box]), @@ -214,7 +222,8 @@ def show(self): ] ) - return widget + with self.sidecar: + return display(widget) def close(self): """Close the widget""" diff --git a/mesmerize_viz/_cnmf/_wrapper.py b/mesmerize_viz/_cnmf/_wrapper.py index 1e409a6..94ffd4f 100644 --- a/mesmerize_viz/_cnmf/_wrapper.py +++ b/mesmerize_viz/_cnmf/_wrapper.py @@ -266,7 +266,7 @@ def __init__( def set_component_index(self, index: int): # TODO: more elegant way than skip_heatmap for g in self.contour_graphics: - g._set_feature(feature="colors", new_data="w", indices=index) + g.set_feature(feature="colors", new_data="w", indices=index) for g in self.temporal_graphics: g.data = self._current_temporal_components[index] From 7e45a605de153d9486f1be8f9a8439852d791dec Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Fri, 13 Oct 2023 20:50:17 -0400 Subject: [PATCH 07/36] major refactor of mcorr, use ImageWidget.set_data(), so much simpler now :D --- mesmerize_viz/__init__.py | 2 +- mesmerize_viz/_common.py | 173 ----------------------- mesmerize_viz/_mcorr.py | 289 +++++++++++++++----------------------- mesmerize_viz/_utils.py | 94 +++++++------ 4 files changed, 171 insertions(+), 387 deletions(-) delete mode 100644 mesmerize_viz/_common.py diff --git a/mesmerize_viz/__init__.py b/mesmerize_viz/__init__.py index bac607d..c15d829 100644 --- a/mesmerize_viz/__init__.py +++ b/mesmerize_viz/__init__.py @@ -1,2 +1,2 @@ -from ._mcorr import MCorrExtensionsViz, MCorrDataFrameVizExtension +from ._mcorr import MCorrDataFrameVizExtension from ._cnmf import CNMFDataFrameVizExtension \ No newline at end of file diff --git a/mesmerize_viz/_common.py b/mesmerize_viz/_common.py deleted file mode 100644 index d3bdaba..0000000 --- a/mesmerize_viz/_common.py +++ /dev/null @@ -1,173 +0,0 @@ -from typing import * -from functools import partial - -import numpy as np -import pandas as pd -from fastplotlib import ImageWidget -from fastplotlib.utils import quick_min_max -from mesmerize_core import MCorrExtensions - -from ._utils import ZeroArray - - -class ImageWidgetWrapper: - """Wraps Image Widget in a way that allows updating the data""" - def __init__( - self, - data: List[str], - data_mapping: dict, - image_widget_managed_data: List[str], - reset_timepoint_on_change: bool = False, - input_movie_kwargs: dict = None, - image_widget_kwargs: dict = None, - ): - """ - Visualize motion correction output. - - Parameters - ---------- - data: list of str, default ["input", "mcorr"] - list of data to plot, can also be a list of lists. - - reset_timepoint_on_change: bool, default False - reset the timepoint in the ImageWidget when changing items/rows - - input_movie_kwargs: dict, optional - kwargs passed to get_input_movie() - - image_widget_kwargs: dict, optional - kwargs passed to ImageWidget - - Returns - ------- - ImageWidget - fastplotlib.ImageWidget visualization - """ - - if input_movie_kwargs is None: - input_movie_kwargs = dict() - - if image_widget_kwargs is None: - image_widget_kwargs = dict() - - # the ones which are [t, x, y] images that ImageWidget can manage by itself - self.image_widget_managed_data = image_widget_managed_data - - # data arrays directly passed to image widget - data_arrays_iw = self._parse_data( - data=data, - data_mapping=data_mapping, - input_movie_kwargs=input_movie_kwargs - ) - - # default kwargs unless user has specified - default_iw_kwargs = { - "window_funcs": {"t": (np.mean, 11)}, - "vmin_vmax_sliders": True, - "cmap": "gnuplot2" - } - - image_widget_kwargs = { - **default_iw_kwargs, - **image_widget_kwargs - } - - self.image_widget = ImageWidget( - data=data_arrays_iw, - names=data, - **image_widget_kwargs - ) - - for a, n in zip(data_arrays_iw, data): - if isinstance(a, ZeroArray): - # rename the existing graphic - self.image_widget.plot[n].graphics[0].name = "zero-array-ignore" - # get the real data - func = data_mapping[n] - real_data = func() - # create graphic with the real data, this will not be managed by ImageWidget - self.image_widget.plot[n].add_image(real_data, name="image", cmap="gnuplot2") - - self.reset_timepoint_on_change = reset_timepoint_on_change - - def _parse_data( - self, - data: List[str], - data_mapping: dict, - input_movie_kwargs: dict, - ) -> List[Union[np.ndarray, ZeroArray]]: - """ - Parse data string keys into actual arrays using the data_mapping for the current row. - Returns list of arrays that ImageWidget can display and manage. - """ - # data arrays directly passed to image widget - data_arrays_iw = list() - - for d in data: - if d in self.image_widget_managed_data: - func = data_mapping[d] - - if d == "input": - a = func(**input_movie_kwargs) - else: - a = func() - - data_arrays_iw.append(a) - - elif d == "empty": - zero_array = ZeroArray(ndim=data_arrays_iw[0].ndim) - data_arrays_iw.append(zero_array) - - else: - # make a placeholder array to keep imagewidget happy - # hacky but this is the best way for now - zero_array = ZeroArray(ndim=data_arrays_iw[0].ndim) - data_arrays_iw.append(zero_array) - - return data_arrays_iw - - def change_data(self, data: List[str], data_mapping: dict, input_movie_kwargs): - data_arrays_iw = self._parse_data( - data=data, - data_mapping=data_mapping, - input_movie_kwargs=input_movie_kwargs - ) - - if not len(data) == len(data_arrays_iw): - raise ValueError("len(data) != len(data_arrays)") - - for i, (name, array) in enumerate(zip(data, data_arrays_iw)): - # skip the ones which ImageWidget does not manage - if name in self.image_widget_managed_data: - # update the ones which ImageWidget manages - self.image_widget._data[i] = array - - # I think it's useful to NOT reset the vmin vmax - # if necessary the user can call ImageWidget.reset_vmin_vmax() - - # update the ones which ImageWidget does not manage - self._set_non_managed_arrays( - data=data, - data_arrays_iw=data_arrays_iw, - data_mapping=data_mapping - ) - - if self.reset_timepoint_on_change: - # set index {t: 0} - self.image_widget.current_index = {"t": 0} - else: - # forces graphic data to update in all subplots - self.image_widget.current_index = self.image_widget.current_index - - def _set_non_managed_arrays(self, data, data_arrays_iw, data_mapping): - for a, n in zip(data_arrays_iw, data): - if isinstance(a, ZeroArray): - # get the real data - func = data_mapping[n] - real_data = func() - # change the graphic data - self.image_widget.plot[n]["image"].data = real_data - - min_max = quick_min_max(real_data) - self.image_widget.plot[n]["image"].vmin = min_max[0] - self.image_widget.plot[n]["image"].vmax = min_max[1] diff --git a/mesmerize_viz/_mcorr.py b/mesmerize_viz/_mcorr.py index 71445cb..863ff11 100644 --- a/mesmerize_viz/_mcorr.py +++ b/mesmerize_viz/_mcorr.py @@ -1,4 +1,3 @@ -from collections import OrderedDict from typing import * from functools import partial from warnings import warn @@ -6,18 +5,14 @@ import numpy as np import pandas as pd from mesmerize_core import MCorrExtensions -from mesmerize_core.caiman_extensions._utils import validate as validate_algo from fastplotlib import ImageWidget +from sidecar import Sidecar +from IPython.display import display from ipydatagrid import DataGrid from ipywidgets import Textarea, VBox, HBox, Layout, IntSlider, Checkbox -from ._utils import validate_data_options, ZeroArray, format_params -from ._common import ImageWidgetWrapper - - -# these are directly manged by the image widget since they are [t, x, y] -image_widget_managed = ["input", "mcorr"] +from ._utils import DummyMovie, format_params projs = [ @@ -26,6 +21,13 @@ "std", ] +VALID_DATA_OPTIONS = ( + "input", + "mcorr", + "corr", + *projs, +) + def get_mcorr_data_mapping(series: pd.Series) -> dict: """ @@ -60,7 +62,7 @@ class McorrVizContainer: def __init__( self, dataframe: pd.DataFrame, - data: List[str] = None, + data_options: List[str] = None, start_index: int = 0, reset_timepoint_on_change: bool = False, input_movie_kwargs: dict = None, @@ -72,7 +74,7 @@ def __init__( Parameters ---------- - data: list of str, default ["input", "mcorr", "mean", "corr"] + data_options: list of str, default ["input", "mcorr", "mean", "corr"] list of data to plot, valid options are: +-------------+-------------------------------------+ @@ -102,9 +104,9 @@ def __init__( data_grid_kwargs: dict, optional kwargs passed to DataGrid() """ - if data is None: + if data_options is None: # default viz - data = ["input", "mcorr", "mean", "corr"] + data_options = ["input", "mcorr", "mean", "corr"] if data_grid_kwargs is None: data_grid_kwargs = dict() @@ -151,7 +153,7 @@ def __init__( ) # data options is private since this can't be changed once an image widget has been made - self._data = data + self._data_options = data_options if input_movie_kwargs is None: input_movie_kwargs = dict() @@ -159,17 +161,26 @@ def __init__( if image_widget_kwargs is None: image_widget_kwargs = dict() + # default kwargs unless user has specified more + default_iw_kwargs = { + "window_funcs": {"t": (np.mean, 11)}, + "vmin_vmax_sliders": True, + "cmap": "gnuplot2" + } + + image_widget_kwargs = { + **default_iw_kwargs, + **image_widget_kwargs # anything in default gets replaced with user-specified entries if present + } + self.input_movie_kwargs = input_movie_kwargs self.image_widget_kwargs = image_widget_kwargs - self._reset_timepoint_on_change = reset_timepoint_on_change + self.reset_timepoint_on_change = reset_timepoint_on_change self.image_widget: ImageWidget = None - self._image_widget_wrapper: ImageWidgetWrapper = None self.current_row: int = start_index - # set the initial widget state with the start index - self._make_image_widget(index=start_index) self._set_params_text_area(index=start_index) # set initial selected row @@ -184,43 +195,73 @@ def __init__( # callback when row changed self.datagrid.observe(self._row_changed, names="selections") - def _set_mean_window_size(self, change): - self.image_widget.window_funcs = {"t": (np.mean, change["new"])} - self.image_widget.current_index = self.image_widget.current_index - - def _set_mean_diff(self, change): - # TODO: will do later - pass + # set the initial widget state with the start index + data_arrays = self._get_row_data(index=start_index) - def _make_image_widget(self, index): - self._image_widget_wrapper = ImageWidgetWrapper( - data=self._data, - data_mapping=get_mcorr_data_mapping(self._dataframe.iloc[index]), - image_widget_managed_data=image_widget_managed, - reset_timepoint_on_change=self._reset_timepoint_on_change, - input_movie_kwargs=self.input_movie_kwargs, - image_widget_kwargs=self.image_widget_kwargs + self.image_widget = ImageWidget( + data=data_arrays, + names=self._data_options, + **self.image_widget_kwargs ) - self.image_widget = self._image_widget_wrapper.image_widget - # mean window slider - self._slider_mean_window = IntSlider( + self.slider_mean_window = IntSlider( min=1, step=2, max=99, - value=33, + value=self.image_widget.window_funcs["t"].window_size, # set from the image widget description="mean wind", description_tooltip="set a mean rolling window" ) - self._slider_mean_window.observe(self._set_mean_window_size, "value") + self.slider_mean_window.observe(self._set_mean_window_size, "value") # TODO: mean diff checkbox - #self._checkbox_mean_diff + # self._checkbox_mean_diff + + self.sidecar = None - self.image_widget.window_funcs = {"t": (np.mean, self._slider_mean_window.value)} + def _set_mean_window_size(self, change): + self.image_widget.window_funcs = {"t": (np.mean, change["new"])} - def _get_selection_row(self) -> Union[int, None]: + # set same index, forces ImageWidget to run process_indices() so the image shown updates using the new window + self.image_widget.current_index = self.image_widget.current_index + + def _set_mean_diff(self, change): + # TODO: will do later + pass + + def _get_row_data(self, index: int) -> List[np.ndarray]: + data_arrays: List[np.ndarray] = list() + + data_mapping = get_mcorr_data_mapping(self._dataframe.iloc[index]) + + mcorr = data_mapping["mcorr"]() + + shape = mcorr.shape + ndim = mcorr.ndim + size = mcorr.size + + # go through all data options user has chosen + for option in self._data_options: + func = data_mapping[option] + + if option == "input": + # kwargs, such as using a specific input movie loader + array = func(**self.input_movie_kwargs) + + else: + # just fetch the array + array = func() + + # for 2D images + if array.ndim == 2: + array = DummyMovie(array, shape=shape, ndim=ndim, size=size) + + data_arrays.append(array) + + return data_arrays + + def _get_selected_row(self) -> Union[int, None]: r1 = self.datagrid.selections[0]["r1"] r2 = self.datagrid.selections[0]["r2"] @@ -235,32 +276,32 @@ def _get_selection_row(self) -> Union[int, None]: return index def _row_changed(self, *args): - index = self._get_selection_row() + index = self._get_selected_row() if index is None: return if self.current_row == index: return - if self.image_widget is None: - self._make_image_widget(index) - return - try: - self._image_widget_wrapper.change_data( - data=self._data, - data_mapping=get_mcorr_data_mapping(self._dataframe.iloc[index]), - input_movie_kwargs=self.input_movie_kwargs - ) + # fetch the data for this row + data_arrays = self._get_row_data(index) + except Exception as e: self.params_text_area.value = f"{type(e).__name__}\n" \ f"{str(e)}\n\n" \ f"See jupyter log for details" raise e - self._set_params_text_area(index) - - self.current_row = index + else: + # no exceptions, set ImageWidget + self.image_widget.set_data( + new_data=data_arrays, + reset_vmin_vmax=False, + reset_indices=self.reset_timepoint_on_change + ) + self._set_params_text_area(index) + self.current_row = index def _set_params_text_area(self, index): row = self._dataframe.iloc[index] @@ -269,9 +310,9 @@ def _set_params_text_area(self, index): param_diffs = self._dataframe.caiman.get_params_diffs( algo=row["algo"], item_name=row["item_name"] - ).iloc[index] + ).loc[index] - diffs_dict = {"diffs": param_diffs} + diffs_dict = {"diffs": param_diffs.to_dict()} diffs = f"{format_params(diffs_dict, 0)}\n\n" except: diffs = "" @@ -284,11 +325,17 @@ def show(self): Show the widget """ - return VBox([ - HBox([self.datagrid, self.params_text_area]), - self.image_widget.show(), - self._slider_mean_window - ]) + if self.sidecar is None: + self.sidecar = Sidecar() + + self.image_widget.reset_vmin_vmax() + + with self.sidecar: + return display(VBox([ + HBox([self.datagrid, self.params_text_area]), + self.image_widget.show(sidecar=False), + self.slider_mean_window + ])) @pd.api.extensions.register_dataframe_accessor("mcorr") @@ -298,7 +345,7 @@ def __init__(self, df): def viz( self, - data: List[str] = None, + data_options: List[str] = None, start_index: int = 0, reset_timepoint_on_change: bool = False, input_movie_kwargs=None, @@ -310,8 +357,8 @@ def viz( Parameters ---------- - data: list of str, default ["input", "mcorr", "mean", "corr"] - list of data to plot, valid options are: + data_options: list of str, default ["input", "mcorr", "mean", "corr"] + list of data options to plot, valid options are: +-------------+-------------------------------------+ | data option | description | @@ -345,9 +392,15 @@ def viz( McorrVizContainer widget that contains the DataGrid, params text box and ImageWidget """ + + for d in data_options: + if d not in VALID_DATA_OPTIONS: + raise KeyError(f"Invalid data option: \"{d}\", valid options are:" + f"\n{VALID_DATA_OPTIONS}") + container = McorrVizContainer( dataframe=self._dataframe, - data=data, + data_options=data_options, start_index=start_index, reset_timepoint_on_change=reset_timepoint_on_change, input_movie_kwargs=input_movie_kwargs, @@ -356,111 +409,3 @@ def viz( ) return container - - -@pd.api.extensions.register_series_accessor("mcorr") -class MCorrExtensionsViz(MCorrExtensions): - @property - def _data_mapping(self) -> Dict[str, callable]: - projections = {k: partial(self._series.caiman.get_projection, k) for k in projs} - m = { - "input": self._series.caiman.get_input_movie, - "mcorr": self.get_output, - "corr": self._series.caiman.get_corr_image, - **projections - } - return m - - @property - def _zero_array(self): - mcorr = self.get_output() - return ZeroArray(ndim=mcorr.ndim, n_frames=mcorr.shape[0]) - - @validate_algo("mcorr") - @validate_data_options() - def viz( - self, - data: List[str] = None, - input_movie_kwargs: dict = None, - image_widget_kwargs: dict = None, - ): - """ - Visualize motion correction output. - - Parameters - ---------- - data: list of str, default ["input", "mcorr"] - list of data to plot, can also be a list of lists. - - input_movie_kwargs: dict, optional - kwargs passed to get_input_movie() - - image_widget_kwargs: dict, optional - kwargs passed to ImageWidget - - Returns - ------- - ImageWidget - fastplotlib.ImageWidget visualization - """ - - if data is None: - # default viz - data = ["input", "mcorr"] - - if input_movie_kwargs is None: - input_movie_kwargs = dict() - - if image_widget_kwargs is None: - image_widget_kwargs = dict() - - # data arrays directly passed to image widget - data_arrays_iw = list() - - for d in data: - if d in image_widget_managed: - func = self._data_mapping[d] - - if d == "input": - a = func(**input_movie_kwargs) - else: - a = func() - - data_arrays_iw.append(a) - - else: - # make a placeholder array to keep imagewidget happy - # hacky but this is the best way for now - zero_array = self._zero_array - data_arrays_iw.append(zero_array) - - # default kwargs unless user has specified - default_iw_kwargs = { - "window_funcs": {"t": (np.mean, 11)}, - "vmin_vmax_sliders": True, - "cmap": "gnuplot2" - } - - image_widget_kwargs = { - **default_iw_kwargs, - **image_widget_kwargs - } - - iw = ImageWidget( - data=data_arrays_iw, - names=data, - **image_widget_kwargs - ) - - for a, n in zip(data_arrays_iw, data): - if isinstance(a, ZeroArray): - # rename the existing graphic - iw.plot[n].graphics[0].name = "zero-array-ignore" - # get the real data - func = self._data_mapping[n] - real_data = func() - # create graphic with the real data, this will not be managed by ImageWidget - iw.plot[n].add_image(real_data, name="img", cmap="gnuplot2") - - iw.show() - return iw diff --git a/mesmerize_viz/_utils.py b/mesmerize_viz/_utils.py index 9c04c0f..0390daf 100644 --- a/mesmerize_viz/_utils.py +++ b/mesmerize_viz/_utils.py @@ -19,24 +19,23 @@ def validate_data_options(): def dec(func): @wraps(func) def wrapper(self, *args, **kwargs): - if "data" in kwargs: - data = kwargs["data"] + if "data_options" in kwargs: + data_options = kwargs["data_options"] else: if len(args) > 0: - data = args[0] + data_options = args[0] else: # assume the extension func will take care of it # the default data arg is None is nothing is passed return func(self, *args, **kwargs) - # flatten - if any([isinstance(d, (list, tuple)) for d in data]): - data = list(chain.from_iterable(data)) + if any([isinstance(d, (list, tuple)) for d in data_options]): + data_options = list(chain.from_iterable(data_options)) valid_options = list(self._data_mapping.keys()) - for d in data: + for d in data_options: if d not in valid_options: raise KeyError(f"Invalid data option: \"{d}\", valid options are:" f"\n{valid_options}") @@ -47,37 +46,50 @@ def wrapper(self, *args, **kwargs): return dec -class ZeroArray(LazyArray): - """ - This array is used as placeholders to allow mixing data of different ndims in the ImageWidget. - For example this allows having mean, max etc. projections in the same ImageWidget as the - input or mcorr movie. It also allows having LineStacks or Heatmap in the same ImageWidget. - """ - def __init__(self, ndim): - self._shape = [1] * ndim - self.rval = np.zeros(shape=self.shape, dtype=np.int8) - # hack to allow it to work with any other array sizes - self._shape[0] = np.inf - - @property - def dtype(self) -> str: - return "int8" - - @property - def shape(self) -> Tuple[int, int, int]: - return tuple(self._shape) - - @property - def n_frames(self) -> int: - return np.inf - - @property - def min(self) -> float: - return 0.0 - - @property - def max(self) -> float: - return 0.0 - - def _compute_at_indices(self, indices: Union[int, slice]) -> np.ndarray: - return self.rval +class DummyMovie: + """Really really hacky""" + def __init__(self, image: np.ndarray, shape, ndim, size): + self.image = image + self.shape = shape + self.ndim = ndim + self.size = size + + def __getitem__(self, index: Union[int, slice]): + if isinstance(index, tuple): + for s in index: + if isinstance(s, int): + # assumption + index = s + break + + if (s.start is None) and (s.stop is None) and (s.step is None): + continue + else: + # assume that this is the dimension that user has asked for, and we return the image using + # slice size from this dimension + index = s + + if isinstance(index, (slice, range)): + start, stop, step = index.start, index.stop, index.step + + if start is None: + start = 0 + + if stop is None: + # assumption, again this is very hacky + stop = max(self.shape) + + if step is None: + step = 1 + + r = range(start, stop, step) + + n_frames = len(r) + + return np.array([self.image] * n_frames) + + if isinstance(index, int): + return self.image + + else: + raise TypeError(f"DummyMovie only accept int or slice indexing, you have passed: {index}") From 5e5a3dcb36768fdf237f23100a2a5d86a103b51a Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Fri, 13 Oct 2023 22:39:52 -0400 Subject: [PATCH 08/36] fix infinite recursion when clicking components, primitive component color setter --- mesmerize_viz/_cnmf/_viz_container.py | 172 +++++++++++++++++++++++++- mesmerize_viz/_cnmf/_wrapper.py | 104 ++-------------- 2 files changed, 178 insertions(+), 98 deletions(-) diff --git a/mesmerize_viz/_cnmf/_viz_container.py b/mesmerize_viz/_cnmf/_viz_container.py index 1621520..fb1fb6b 100644 --- a/mesmerize_viz/_cnmf/_viz_container.py +++ b/mesmerize_viz/_cnmf/_viz_container.py @@ -1,5 +1,6 @@ import itertools from _warnings import warn +from functools import partial from typing import * import pandas as pd @@ -7,11 +8,86 @@ from ipywidgets import Textarea, Layout, VBox, HBox from IPython.display import display from sidecar import Sidecar +import numpy as np -from ._wrapper import VALID_DATA_OPTIONS, get_cnmf_data_mapping, GridPlotWrapper +from ._wrapper import VALID_DATA_OPTIONS, GridPlotWrapper, projs, ExtensionCallWrapper, TEMPORAL_OPTIONS from .._utils import format_params +def get_cnmf_data_mapping(series: pd.Series, data_kwargs: dict = None, other_data_loaders: dict = None) -> dict: + """ + Returns dict that maps data option str to a callable that can return the corresponding data array. + + For example, ``{"input": series.get_input_movie}`` maps "input" -> series.get_input_movie + + Parameters + ---------- + series: pd.Series + row/item to get mapping from + + data_kwargs: dict, optional + optional kwargs for each of the extension functions + + other_data_loaders: dict + {"data_option": callable}, example {"behavior": LazyVideo} + + Returns + ------- + dict + {data label: callable} + """ + if data_kwargs is None: + data_kwargs = dict() + + if other_data_loaders is None: + other_data_loaders = dict() + + default_extension_kwargs = {k: dict() for k in VALID_DATA_OPTIONS + list(other_data_loaders.keys())} + + default_extension_kwargs["contours"] = {"swap_dim": False} + + ext_kwargs = { + **default_extension_kwargs, + **data_kwargs + } + + projections = {k: partial(series.caiman.get_projection, k) for k in projs} + + other_data_loaders_mapping = dict() + + # make ExtensionCallWrapers for other data loaders + for option in list(other_data_loaders.keys()): + other_data_loaders_mapping[option] = ExtensionCallWrapper(other_data_loaders[option], ext_kwargs[option]) + + rcm_rcb_projs = dict() + for proj in ["mean", "min", "max", "std"]: + rcm_rcb_projs[f"rcm-{proj}"] = ExtensionCallWrapper( + series.cnmf.get_rcm, + ext_kwargs["rcm"], + attr=f"{proj}_image" + ) + + temporal_mappings = { + k: ExtensionCallWrapper(series.cnmf.get_temporal, ext_kwargs[k]) for k in TEMPORAL_OPTIONS + } + + m = { + "input": ExtensionCallWrapper(series.caiman.get_input_movie, ext_kwargs["input"]), + "rcm": ExtensionCallWrapper(series.cnmf.get_rcm, ext_kwargs["rcm"]), + "rcb": ExtensionCallWrapper(series.cnmf.get_rcb, ext_kwargs["rcb"]), + "residuals": ExtensionCallWrapper(series.cnmf.get_residuals, ext_kwargs["residuals"]), + "corr": ExtensionCallWrapper(series.caiman.get_corr_image, ext_kwargs["corr"]), + "contours": ExtensionCallWrapper(series.cnmf.get_contours, ext_kwargs["contours"]), + "empty": None, + **temporal_mappings, + **projections, + **rcm_rcb_projs, + **other_data_loaders_mapping + } + + return m + + class CNMFVizContainer: """Widget that contains the DataGrid, params text box fastplotlib GridPlot, etc""" @@ -284,3 +360,97 @@ def _set_params_text_area(self, index): # diffs and full params self.params_text_area.value = diffs + format_params(self._dataframe.iloc[index].params, 0) + + + @property + def cmap(self) -> str: + return self._gridplot_wrapper.cmap + + @cmap.setter + def cmap(self, cmap: str): + for g in self._gridplot_wrapper.image_graphics: + g.cmap = cmap + + def set_component_colors( + self, + colors: Union[str, np.ndarray], + cmap: str = None, + visible: str = "all" + ): + """ + + Parameters + ---------- + colors: str or np.ndarray + np.ndarray or one of: random, accepted, rejected, accepted-rejected, snr_comps, snr_comps_log, + r_values, cnn_preds + + If np.ndarray, it must be of the same length as the number of components + + cmap: str + custom cmap for the colors + + visible: str + one of: all, accepted, rejected + + Returns + ------- + + """ + if colors == "random": + for contours in self._gridplot_wrapper.contour_graphics: + contours[:].colors = "random" + + cnmf_obj = self._dataframe.iloc[self.current_row].cnmf.get_output() + + n_contours = len(self._gridplot_wrapper.contour_graphics[0]) + + if colors in ["accepted", "rejected", "accepted-rejected"]: + if cmap is None: + cmap = "Set1" + + # make a empty array for cmap_values + classifier = np.zeros(n_contours, dtype=int) + # set the accepted components to 1 + classifier[cnmf_obj.estimates.idx_components] = 1 + + else: + if cmap is None: + cmap = "spring" + + elif colors == "snr_comps": + classifier = cnmf_obj.estimates.SNR_comp + + elif colors == "snr_comps_log": + classifier = np.log10(cnmf_obj.estimates.SNR_comp) + + elif colors == "r_values": + classifier = cnmf_obj.estimates.r_values + + elif colors == "cnn_preds": + classifier = cnmf_obj.estimates.cnn_preds + + elif isinstance(colors, np.ndarray): + if not colors.size == n_contours: + raise ValueError(f"If using np.ndarray cor component_colors, the array size must be " + f"the same as n_contours: {n_contours}, your array size is: {colors.size}") + + classifier = colors + + else: + raise ValueError("Invalid component_colors value") + + for contours in self._gridplot_wrapper.contour_graphics: + contours.cmap = cmap + contours.cmap_values = classifier + + # choose to make all or accepted or rejected visible + if visible == "accepted": + contours[cnmf_obj.estimates.idx_components_bad].colors[:, -1] = 0 + + elif visible == "rejected": + contours[cnmf_obj.estimates.idx_components].colors[:, -1] = 0 + + else: + # make everything visible + contours[:].colors[:, -1] = 1 diff --git a/mesmerize_viz/_cnmf/_wrapper.py b/mesmerize_viz/_cnmf/_wrapper.py index 94ffd4f..1b16012 100644 --- a/mesmerize_viz/_cnmf/_wrapper.py +++ b/mesmerize_viz/_cnmf/_wrapper.py @@ -1,14 +1,12 @@ -from functools import partial from itertools import product from typing import Union, List, Dict import numpy as np -import pandas as pd from ipywidgets import IntSlider, BoundedIntText, jslink from fastplotlib import GridPlot, graphics -from fastplotlib.graphics.selectors import LinearSelector, Synchronizer, LinearRegionSelector +from fastplotlib.graphics.selectors import LinearSelector, Synchronizer from fastplotlib.utils import calculate_gridshape @@ -89,80 +87,6 @@ def __call__(self, *args, **kwargs): return rval -def get_cnmf_data_mapping(series: pd.Series, data_kwargs: dict = None, other_data_loaders: dict = None) -> dict: - """ - Returns dict that maps data option str to a callable that can return the corresponding data array. - - For example, ``{"input": series.get_input_movie}`` maps "input" -> series.get_input_movie - - Parameters - ---------- - series: pd.Series - row/item to get mapping from - - data_kwargs: dict, optional - optional kwargs for each of the extension functions - - other_data_loaders: dict - {"data_option": callable}, example {"behavior": LazyVideo} - - Returns - ------- - dict - {data label: callable} - """ - if data_kwargs is None: - data_kwargs = dict() - - if other_data_loaders is None: - other_data_loaders = dict() - - default_extension_kwargs = {k: dict() for k in VALID_DATA_OPTIONS + list(other_data_loaders.keys())} - - default_extension_kwargs["contours"] = {"swap_dim": False} - - ext_kwargs = { - **default_extension_kwargs, - **data_kwargs - } - - projections = {k: partial(series.caiman.get_projection, k) for k in projs} - - other_data_loaders_mapping = dict() - - # make ExtensionCallWrapers for other data loaders - for option in list(other_data_loaders.keys()): - other_data_loaders_mapping[option] = ExtensionCallWrapper(other_data_loaders[option], ext_kwargs[option]) - - rcm_rcb_projs = dict() - for proj in ["mean", "min", "max", "std"]: - rcm_rcb_projs[f"rcm-{proj}"] = ExtensionCallWrapper( - series.cnmf.get_rcm, - ext_kwargs["rcm"], - attr=f"{proj}_image" - ) - - temporal_mappings = { - k: ExtensionCallWrapper(series.cnmf.get_temporal, ext_kwargs[k]) for k in TEMPORAL_OPTIONS - } - - m = { - "input": ExtensionCallWrapper(series.caiman.get_input_movie, ext_kwargs["input"]), - "rcm": ExtensionCallWrapper(series.cnmf.get_rcm, ext_kwargs["rcm"]), - "rcb": ExtensionCallWrapper(series.cnmf.get_rcb, ext_kwargs["rcb"]), - "residuals": ExtensionCallWrapper(series.cnmf.get_residuals, ext_kwargs["residuals"]), - "corr": ExtensionCallWrapper(series.caiman.get_corr_image, ext_kwargs["corr"]), - "contours": ExtensionCallWrapper(series.cnmf.get_contours, ext_kwargs["contours"]), - "empty": None, - **temporal_mappings, - **projections, - **rcm_rcb_projs, - **other_data_loaders_mapping - } - - return m - - class GridPlotWrapper: """Wraps GridPlot in a way that allows updating the data""" @@ -212,7 +136,7 @@ def __init__( if gridplot_kwargs is None: gridplot_kwargs = dict() - self._cmap = cmap + self.cmap = cmap self.component_colors = component_colors @@ -281,6 +205,11 @@ def set_component_index(self, index: int): def _heatmap_set_component_index(self, ev): index = ev.pick_info["selected_index"] + if ev.pick_info["pygfx_event"] is None: + # this means that the selector was not triggered by the user but that it moved due to another event + # so we don't set_component_index because then infinite recursion + return + self.set_component_index(index) def _parse_data(self, data_options, data_mapping) -> List[List[np.ndarray]]: @@ -303,25 +232,6 @@ def _parse_data(self, data_options, data_mapping) -> List[List[np.ndarray]]: return data_arrays - @property - def cmap(self) -> str: - return self._cmap - - @cmap.setter - def cmap(self, cmap: str): - for g in self.image_graphics: - g.cmap = cmap - - # @property - # def component_colors(self) -> Any: - # pass - # - # @component_colors.setter - # def component_colors(self, colors: Any): - # for collection in self.contour_graphics: - # for g in collection.graphics: - # - def change_data(self, data_mapping: Dict[str, callable]): """ Changes the data shown in the gridplot. From 6f9128df32cea0cd671d0d722216604e5509f6d9 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Sat, 14 Oct 2023 02:36:33 -0400 Subject: [PATCH 09/36] contour color handling and zoom components works :D --- mesmerize_viz/_cnmf/_viz_container.py | 118 ++++++++++++++++++++------ mesmerize_viz/_cnmf/_wrapper.py | 50 ++++++++++- 2 files changed, 138 insertions(+), 30 deletions(-) diff --git a/mesmerize_viz/_cnmf/_viz_container.py b/mesmerize_viz/_cnmf/_viz_container.py index fb1fb6b..5fa25a4 100644 --- a/mesmerize_viz/_cnmf/_viz_container.py +++ b/mesmerize_viz/_cnmf/_viz_container.py @@ -5,7 +5,7 @@ import pandas as pd from ipydatagrid import DataGrid -from ipywidgets import Textarea, Layout, VBox, HBox +from ipywidgets import Textarea, Layout, VBox, HBox, RadioButtons, Dropdown, FloatSlider from IPython.display import display from sidecar import Sidecar import numpy as np @@ -224,6 +224,8 @@ def __init__( self.current_row: int = start_index + self._random_colors = None + self._make_gridplot( start_index=start_index, reset_timepoint_on_change=reset_timepoint_on_change, @@ -247,6 +249,39 @@ def __init__( # callback when row changed self.datagrid.observe(self._row_changed, names="selections") + self._dropdown_contour_colors = Dropdown( + options=["random", "accepted", "rejected", "snr_comps", "snr_comps_log", "r_values", "cnn_preds"], + value="random", + description='contour colors:', + ) + + self._dropdown_contour_colors.observe(self._ipywidget_set_component_colors, "value") + + self._radio_visible_components = RadioButtons( + options=["all", "accepted", "rejected"], + description_tooltip="contours to make visible", + description="visible contours" + ) + + self._radio_visible_components.observe(self._ipywidget_set_component_colors, "value") + + self._spinbox_alpha_invisible_contours = FloatSlider( + value=0.0, + min=0.0, + max=1.0, + step=0.1, + description="invisible alpha:", + description_tooltip="transparency of contours set to be invisible", + disabled=False + ) + + self._spinbox_alpha_invisible_contours.observe(self._ipywidget_set_component_colors, "value") + + self._box_contour_controls = VBox([ + self._dropdown_contour_colors, + HBox([self._radio_visible_components, self._spinbox_alpha_invisible_contours]) + ]) + self.sidecar = None def _make_gridplot( @@ -264,6 +299,12 @@ def _make_gridplot( self.data_kwargs ) + cnmf_obj = self._dataframe.iloc[start_index].cnmf.get_output() + n_contours = cnmf_obj.estimates.C.shape[0] + + self._random_colors = np.random.rand(n_contours, 4).astype(np.float32) + self._random_colors[:, -1] = 1 + self._gridplot_wrapper = GridPlotWrapper( data=self._data, data_mapping=data_mapping, @@ -284,17 +325,20 @@ def show(self, sidecar: bool = True): gridplots_widget = [gp.show(sidecar=False) for gp in self.gridplots] if "Jupyter" in self.gridplots[0].canvas.__class__.__name__: - vbox_elements = gridplots_widget + gridplot_elements = gridplots_widget else: - vbox_elements = list() + gridplot_elements = list() if self.sidecar is None: self.sidecar = Sidecar() widget = VBox( [ - HBox([self.datagrid, self.params_text_area]), HBox([self._gridplot_wrapper.component_slider, self._gridplot_wrapper.component_int_box]), - VBox(vbox_elements) + HBox([self.datagrid, self.params_text_area]), + HBox([self._gridplot_wrapper.component_slider, self._gridplot_wrapper.component_int_box]), + VBox(gridplot_elements), + HBox([self._gridplot_wrapper.checkbox_zoom_components, self._gridplot_wrapper.zoom_components_scale]), + self._box_contour_controls ] ) @@ -342,6 +386,12 @@ def _row_changed(self, *args): self._set_params_text_area(index) + cnmf_obj = self._dataframe.iloc[index].cnmf.get_output() + n_contours = cnmf_obj.estimates.C.shape[0] + + self._random_colors = np.random.rand(n_contours, 4).astype(np.float32) + self._random_colors[:, -1] = 1 + self.current_row = index def _set_params_text_area(self, index): @@ -371,11 +421,30 @@ def cmap(self, cmap: str): for g in self._gridplot_wrapper.image_graphics: g.cmap = cmap + def _set_component_visibility(self, contours, cnmf_obj): + visible = self._radio_visible_components.value + alpha_invisible = self._spinbox_alpha_invisible_contours.value + + # choose to make all or accepted or rejected visible + if visible == "accepted": + contours[cnmf_obj.estimates.idx_components_bad].colors[:, -1] = alpha_invisible + + elif visible == "rejected": + contours[cnmf_obj.estimates.idx_components].colors[:, -1] = alpha_invisible + + else: + # make everything visible + contours[:].colors[:, -1] = 1 + + def _ipywidget_set_component_colors(self, *args): + """just a wrapper to make ipywidgets happy""" + colors = self._dropdown_contour_colors.value + self.set_component_colors(colors) + def set_component_colors( self, colors: Union[str, np.ndarray], cmap: str = None, - visible: str = "all" ): """ @@ -390,22 +459,24 @@ def set_component_colors( cmap: str custom cmap for the colors - visible: str - one of: all, accepted, rejected - Returns ------- """ + cnmf_obj = self._dataframe.iloc[self.current_row].cnmf.get_output() + n_contours = len(self._gridplot_wrapper.contour_graphics[0]) + if colors == "random": + colors = self._random_colors for contours in self._gridplot_wrapper.contour_graphics: - contours[:].colors = "random" + for i, g in enumerate(contours.graphics): + g.colors = colors[i] - cnmf_obj = self._dataframe.iloc[self.current_row].cnmf.get_output() + self._set_component_visibility(contours, cnmf_obj) - n_contours = len(self._gridplot_wrapper.contour_graphics[0]) + return - if colors in ["accepted", "rejected", "accepted-rejected"]: + if colors in ["accepted", "rejected"]: if cmap is None: cmap = "Set1" @@ -418,7 +489,7 @@ def set_component_colors( if cmap is None: cmap = "spring" - elif colors == "snr_comps": + if colors == "snr_comps": classifier = cnmf_obj.estimates.SNR_comp elif colors == "snr_comps_log": @@ -438,19 +509,14 @@ def set_component_colors( classifier = colors else: - raise ValueError("Invalid component_colors value") + raise ValueError("Invalid colors value") for contours in self._gridplot_wrapper.contour_graphics: - contours.cmap = cmap - contours.cmap_values = classifier + # first initialize using a quantitative cmap + # this ensures that setting cmap_values will work + contours.cmap = "gray" - # choose to make all or accepted or rejected visible - if visible == "accepted": - contours[cnmf_obj.estimates.idx_components_bad].colors[:, -1] = 0 - - elif visible == "rejected": - contours[cnmf_obj.estimates.idx_components].colors[:, -1] = 0 + contours.cmap_values = classifier + contours.cmap = cmap - else: - # make everything visible - contours[:].colors[:, -1] = 1 + self._set_component_visibility(contours, cnmf_obj) diff --git a/mesmerize_viz/_cnmf/_wrapper.py b/mesmerize_viz/_cnmf/_wrapper.py index 1b16012..4fe296c 100644 --- a/mesmerize_viz/_cnmf/_wrapper.py +++ b/mesmerize_viz/_cnmf/_wrapper.py @@ -3,7 +3,7 @@ import numpy as np -from ipywidgets import IntSlider, BoundedIntText, jslink +from ipywidgets import IntSlider, BoundedIntText, jslink, Checkbox, FloatSlider from fastplotlib import GridPlot, graphics from fastplotlib.graphics.selectors import LinearSelector, Synchronizer @@ -157,9 +157,20 @@ def __init__( # gridplot for each sublist for sub_data in self._data: - _gridplot_kwargs = {"shape": calculate_gridshape(len(sub_data))} - _gridplot_kwargs.update(gridplot_kwargs) - self.gridplots.append(GridPlot(**_gridplot_kwargs)) + # make the kwargs + final_gridplot_kwargs = { + "shape": calculate_gridshape(len(sub_data)), + "controllers": "sync" + } + # merge with any use-specified kwargs + # user-specified kwargs will override anything specified here + + final_gridplot_kwargs.update(gridplot_kwargs) + + # instantiate gridplot and add to list of gridplots + self.gridplots.append( + GridPlot(**final_gridplot_kwargs) + ) self.temporal_graphics: List[graphics.LineGraphic] = list() self.temporal_stack_graphics: List[graphics.LineStack] = list() @@ -185,8 +196,37 @@ def __init__( self._current_temporal_components: np.ndarray = None + self.checkbox_zoom_components = Checkbox( + value=True, + description="auto-zoom component", + description_tooltip="If checked, zoom into selected component" + ) + + self.zoom_components_scale = FloatSlider( + min=0.25, + max=3, + value=1, + step=0.25, + description="zoom scale", + description_tooltip="scale if zoom components is checked" + ) + self.change_data(data_mapping) + def _zoom_into_component(self, index: int): + if not self.checkbox_zoom_components.value: + return + + for gridplot in self.gridplots: + for subplot in gridplot: + if "contours" not in subplot: + continue + + subplot.camera.show_object( + subplot["contours"].graphics[index].world_object, + scale=self.zoom_components_scale.value + ) + def set_component_index(self, index: int): # TODO: more elegant way than skip_heatmap for g in self.contour_graphics: @@ -202,6 +242,8 @@ def set_component_index(self, index: int): self.component_int_box.value = index + self._zoom_into_component(index) + def _heatmap_set_component_index(self, ev): index = ev.pick_info["selected_index"] From ecc3b4facbc8015c6bc8709eacb908d1c810b879 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Sat, 14 Oct 2023 02:50:01 -0400 Subject: [PATCH 10/36] disable params viewer --- mesmerize_viz/_cnmf/_viz_container.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mesmerize_viz/_cnmf/_viz_container.py b/mesmerize_viz/_cnmf/_viz_container.py index 5fa25a4..5d3fe1e 100644 --- a/mesmerize_viz/_cnmf/_viz_container.py +++ b/mesmerize_viz/_cnmf/_viz_container.py @@ -211,7 +211,8 @@ def __init__( height="250px", max_height="250px", width="360px", - max_width="500px" + max_width="500px", + disabled=True, ) # data options is private since this can't be changed once an image widget has been made From 272bd34e0506e5ef1759ed25e51a92fa23188d56 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Sat, 14 Oct 2023 02:50:43 -0400 Subject: [PATCH 11/36] disable mcorr params viewer --- mesmerize_viz/_mcorr.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mesmerize_viz/_mcorr.py b/mesmerize_viz/_mcorr.py index 863ff11..30e294b 100644 --- a/mesmerize_viz/_mcorr.py +++ b/mesmerize_viz/_mcorr.py @@ -149,7 +149,8 @@ def __init__( height="250px", max_height="250px", width="360px", - max_width="500px" + max_width="500px", + disabled=True, ) # data options is private since this can't be changed once an image widget has been made From d48cbf13b331d68e5698bd3fc5505c8abdea35eb Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Sat, 14 Oct 2023 03:48:29 -0400 Subject: [PATCH 12/36] add norm, zscore, dfof --- mesmerize_viz/_cnmf/_extensions.py | 15 ++++++--- mesmerize_viz/_cnmf/_viz_container.py | 25 ++++++++++++++- mesmerize_viz/_cnmf/_wrapper.py | 44 +++++++++++++++++++++------ 3 files changed, 69 insertions(+), 15 deletions(-) diff --git a/mesmerize_viz/_cnmf/_extensions.py b/mesmerize_viz/_cnmf/_extensions.py index fc837b7..bb87ebe 100644 --- a/mesmerize_viz/_cnmf/_extensions.py +++ b/mesmerize_viz/_cnmf/_extensions.py @@ -12,7 +12,7 @@ def __init__(self, df): def viz( self, - data: List[str] = None, + data_options: List[str] = None, start_index: int = 0, reset_timepoint_on_change: bool = False, data_graphic_kwargs: dict = None, @@ -29,13 +29,18 @@ def viz( Parameters ---------- - data: list of str or list of list of str - default [["temporal"], ["input", "rcm", "rcb", "residuals"]] + data_options: list of str or list of list of str + default [["temporal"], ["heatmap-norm"], ["input", "rcm", "rcb", "residuals"]] + + **Note:** You may add suffixes to temporal and heatmap options for "dfof", "zscore", "norm", + examples: "temporal-dfof", "heatmap-norm", "heatmap-zscore", "heatmap-dfof", etc. + list of data to plot, valid options are: + +------------------+-----------------------------------------+ + | data option | description | +------------------+-----------------------------------------+ | "input" | input movie | - +==================+=========================================+ | "rcm" | reconstructed movie, A * C | | "rcb" | reconstructed background, b * f | | "residuals" | residuals, input - (A * C) - (b * f) | @@ -74,7 +79,7 @@ def viz( """ container = CNMFVizContainer( dataframe=self._dataframe, - data=data, + data=data_options, start_index=start_index, reset_timepoint_on_change=reset_timepoint_on_change, data_graphic_kwargs=data_graphic_kwargs, diff --git a/mesmerize_viz/_cnmf/_viz_container.py b/mesmerize_viz/_cnmf/_viz_container.py index 5d3fe1e..07c7a77 100644 --- a/mesmerize_viz/_cnmf/_viz_container.py +++ b/mesmerize_viz/_cnmf/_viz_container.py @@ -9,8 +9,12 @@ from IPython.display import display from sidecar import Sidecar import numpy as np +from tslearn.preprocessing import TimeSeriesScalerMeanVariance, TimeSeriesScalerMinMax -from ._wrapper import VALID_DATA_OPTIONS, GridPlotWrapper, projs, ExtensionCallWrapper, TEMPORAL_OPTIONS +from ._wrapper import ( + VALID_DATA_OPTIONS, GridPlotWrapper, projs, ExtensionCallWrapper, TEMPORAL_OPTIONS, + TEMPORAL_OPTIONS_DFOF, TEMPORAL_OPTIONS_ZSCORE, TEMPORAL_OPTIONS_NORM +) from .._utils import format_params @@ -71,6 +75,22 @@ def get_cnmf_data_mapping(series: pd.Series, data_kwargs: dict = None, other_dat k: ExtensionCallWrapper(series.cnmf.get_temporal, ext_kwargs[k]) for k in TEMPORAL_OPTIONS } + dfof_mappings = { + k: ExtensionCallWrapper(series.cnmf.get_detrend_dfof, ext_kwargs[k]) for k in TEMPORAL_OPTIONS_DFOF + } + + zscore_mappings = { + k: ExtensionCallWrapper( + series.cnmf.get_temporal, ext_kwargs[k], post_process_func=TimeSeriesScalerMeanVariance().fit_transform + ) for k in TEMPORAL_OPTIONS_ZSCORE + } + + norm_mappings = { + k: ExtensionCallWrapper( + series.cnmf.get_temporal, ext_kwargs[k], post_process_func=TimeSeriesScalerMinMax().fit_transform + ) for k in TEMPORAL_OPTIONS_NORM + } + m = { "input": ExtensionCallWrapper(series.caiman.get_input_movie, ext_kwargs["input"]), "rcm": ExtensionCallWrapper(series.cnmf.get_rcm, ext_kwargs["rcm"]), @@ -80,6 +100,9 @@ def get_cnmf_data_mapping(series: pd.Series, data_kwargs: dict = None, other_dat "contours": ExtensionCallWrapper(series.cnmf.get_contours, ext_kwargs["contours"]), "empty": None, **temporal_mappings, + **dfof_mappings, + **zscore_mappings, + **norm_mappings, **projections, **rcm_rcb_projs, **other_data_loaders_mapping diff --git a/mesmerize_viz/_cnmf/_wrapper.py b/mesmerize_viz/_cnmf/_wrapper.py index 4fe296c..4ca7384 100644 --- a/mesmerize_viz/_cnmf/_wrapper.py +++ b/mesmerize_viz/_cnmf/_wrapper.py @@ -2,8 +2,7 @@ from typing import Union, List, Dict import numpy as np - -from ipywidgets import IntSlider, BoundedIntText, jslink, Checkbox, FloatSlider +from ipywidgets import IntSlider, BoundedIntText, jslink, Checkbox, FloatSlider, RadioButtons from fastplotlib import GridPlot, graphics from fastplotlib.graphics.selectors import LinearSelector, Synchronizer @@ -40,6 +39,20 @@ "heatmap", ] +TEMPORAL_OPTIONS_DFOF = [ + f"{option}-dfof" for option in TEMPORAL_OPTIONS +] + +TEMPORAL_OPTIONS_ZSCORE = [ + f"{option}-zscore" for option in TEMPORAL_OPTIONS +] + +TEMPORAL_OPTIONS_NORM = [ + f"{option}-norm" for option in TEMPORAL_OPTIONS +] + +TEMPORAL_OPTIONS_ALL = TEMPORAL_OPTIONS + TEMPORAL_OPTIONS_DFOF + TEMPORAL_OPTIONS_ZSCORE + TEMPORAL_OPTIONS_ZSCORE + TEMPORAL_OPTIONS_NORM + projs = [ "mean", "max", @@ -49,11 +62,17 @@ IMAGE_OPTIONS += projs VALID_DATA_OPTIONS += IMAGE_OPTIONS -VALID_DATA_OPTIONS += TEMPORAL_OPTIONS +VALID_DATA_OPTIONS += TEMPORAL_OPTIONS_ALL class ExtensionCallWrapper: - def __init__(self, extension_func: callable, kwargs: dict = None, attr: str = None): + def __init__( + self, + extension_func: callable, + kwargs: dict = None, + attr: str = None, + post_process_func: callable = None, + ): """ Basically like ``functools.partial`` but supports kwargs. @@ -65,9 +84,12 @@ def __init__(self, extension_func: callable, kwargs: dict = None, attr: str = No kwargs: dict kwargs to pass to the extension function when it is called - attr: str, optional + attr: str, optionalself, extension_func: callable, kwargs: dict = None, attr: str = None return an attribute of the callable's output instead of the return value of the callable. Example: if using rcm, can set ``attr="max_image"`` to return the max proj of the RCM. + + post_process_func: callable + A function to postprocess before returning, such as zscore, etc. """ if kwargs is None: @@ -77,6 +99,7 @@ def __init__(self, extension_func: callable, kwargs: dict = None, attr: str = No self.func = extension_func self.attr = attr + self.post_process_func = post_process_func def __call__(self, *args, **kwargs): rval = self.func(**self.kwargs) @@ -84,6 +107,9 @@ def __call__(self, *args, **kwargs): if self.attr is not None: return getattr(rval, self.attr) + if self.post_process_func is not None: + return self.post_process_func(rval) + return rval @@ -374,7 +400,7 @@ def _change_data_gridplot( if data_option == "empty": continue - elif data_option == "temporal": + elif data_option.startswith("temporal") and "stack" not in data_option: # Only few one line at a time current_graphic = subplot.add_line( data_array[0], @@ -395,7 +421,7 @@ def _change_data_gridplot( # scale according to temporal dims subplot.camera.maintain_aspect = False - elif data_option == "temporal-stack": + elif data_option.startswith("temporal-stack"): current_graphic = subplot.add_line_stack( data_array, colors=component_colors, @@ -407,7 +433,7 @@ def _change_data_gridplot( # scale according to temporal dims subplot.camera.maintain_aspect = False - elif data_option == "heatmap": + elif data_option.startswith("heatmap"): current_graphic = subplot.add_heatmap( data_array, name="heatmap", @@ -452,7 +478,7 @@ def _change_data_gridplot( subplot.name = data_option - if data_option in TEMPORAL_OPTIONS: + if data_option in TEMPORAL_OPTIONS_ALL: self.linear_selectors.append(current_graphic.add_linear_selector()) subplot.camera.maintain_aspect = False From 6a8b9dcbafb674a9551dbd30fcc679580f0134d3 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Sat, 14 Oct 2023 03:53:26 -0400 Subject: [PATCH 13/36] update examples --- examples/cnmf.ipynb | 93 ++++++++++++++-------------------------- examples/mcorr.ipynb | 100 ++++++++++++------------------------------- 2 files changed, 58 insertions(+), 135 deletions(-) diff --git a/examples/cnmf.ipynb b/examples/cnmf.ipynb index 402350e..e5daca9 100644 --- a/examples/cnmf.ipynb +++ b/examples/cnmf.ipynb @@ -12,13 +12,16 @@ "name": "stderr", "output_type": "stream", "text": [ - "2023-06-11 07:34:01.073007: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.\n", - "2023-06-11 07:34:01.095009: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.\n", - "2023-06-11 07:34:01.095386: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", + "2023-10-14 03:49:11.570098: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.\n", + "2023-10-14 03:49:11.599719: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", + "2023-10-14 03:49:11.599748: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", + "2023-10-14 03:49:11.599774: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", + "2023-10-14 03:49:11.605509: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.\n", + "2023-10-14 03:49:11.606166: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", - "2023-06-11 07:34:01.605728: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n", - "/home/kushal/repos/mesmerize-viz/mesmerize_viz/_mcorr.py:336: UserWarning: registration of accessor under name 'mcorr' for type is overriding a preexisting attribute with the same name.\n", - " @pd.api.extensions.register_series_accessor(\"mcorr\")\n" + "2023-10-14 03:49:12.496589: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n", + "Possible incompatible version of wgpu:\n", + " Detected 0.11.0, need >=0.10.0, <0.11.0.\n" ] } ], @@ -55,14 +58,14 @@ "name": "stderr", "output_type": "stream", "text": [ - "/home/kushal/venvs/mescore/lib/python3.11/site-packages/ipydatagrid/datagrid.py:445: UserWarning: Index name of 'index' is not round-trippable.\n", + "/home/kushal/venvs/mescore/lib/python3.11/site-packages/ipydatagrid/datagrid.py:460: UserWarning: Index name of 'index' is not round-trippable.\n", " schema = pd.io.json.build_table_schema(dataframe)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "cb892b243f1e4ff1b8d3bea0fe890e9c", + "model_id": "5081da17f0024189a932a8a1e41b2892", "version_major": 2, "version_minor": 0 }, @@ -76,7 +79,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "8836f855be97406d8f2246bd5a1e932f", + "model_id": "52f6f09085584836b3cbd5afb51e7139", "version_major": 2, "version_minor": 0 }, @@ -88,34 +91,26 @@ "output_type": "display_data" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "Decode mmap filename /home/kushal/caiman_data/mesmerize-batch/b3503ab5-6376-4cd4-8c33-7f80735fca82/b3503ab5-6376-4cd4-8c33-7f80735fca82-Sue_2x_3000_40_-46_els__d1_170_d2_170_d3_1_order_F_frames_3000.mmap\n", - "Decode mmap filename /home/kushal/caiman_data/mesmerize-batch/b3503ab5-6376-4cd4-8c33-7f80735fca82/b3503ab5-6376-4cd4-8c33-7f80735fca82-Sue_2x_3000_40_-46_els__d1_170_d2_170_d3_1_order_F_frames_3000.mmap\n" - ] + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "edf7f52fd7bc4b378fb43c4c33c9a854", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "RFBOutputContext()" + ] + }, + "metadata": {}, + "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ - "/home/kushal/repos/fastplotlib/fastplotlib/graphics/features/_base.py:34: UserWarning: converting float64 array to float32\n", + "/home/kushal/repos/fastplotlib/fastplotlib/graphics/_features/_base.py:34: UserWarning: converting float64 array to float32\n", " warn(f\"converting {array.dtype} array to float32\")\n", - "/home/kushal/repos/fastplotlib/fastplotlib/graphics/features/_base.py:123: UserWarning: Event handler > is already registered.\n", - " warn(f\"Event handler {handler} is already registered.\")\n", - "/home/kushal/repos/fastplotlib/fastplotlib/graphics/features/_base.py:123: UserWarning: Event handler Graphics> is already registered.\n", - " warn(f\"Event handler {handler} is already registered.\")\n", - "/home/kushal/repos/fastplotlib/fastplotlib/graphics/features/_base.py:123: UserWarning: Event handler Graphics> is already registered.\n", - " warn(f\"Event handler {handler} is already registered.\")\n", - "/home/kushal/repos/fastplotlib/fastplotlib/graphics/features/_base.py:123: UserWarning: Event handler Graphics> is already registered.\n", - " warn(f\"Event handler {handler} is already registered.\")\n", - "/home/kushal/repos/fastplotlib/fastplotlib/graphics/features/_base.py:123: UserWarning: Event handler Graphics> is already registered.\n", - " warn(f\"Event handler {handler} is already registered.\")\n", - "/home/kushal/repos/mesmerize-viz/mesmerize_viz/_cnmf.py:854: FutureWarning: You are trying to use the following experimental feature, this may change in the future without warning:\n", + "/home/kushal/repos/mesmerize-viz/mesmerize_viz/_cnmf/_viz_container.py:425: FutureWarning: You are trying to use the following experimental feature, this may change in the future without warning:\n", "CaimanDataFrameExtensions.get_params_diffs\n", "This feature is new and the might improve in the future\n", "\n", @@ -124,44 +119,18 @@ } ], "source": [ - "container_widget = df.cnmf.viz(start_index=1)" + "container_widget = df.cnmf.viz(data_options=[[\"temporal\"], [\"heatmap-zscore\"], [\"input\", \"rcm\"]], start_index=1)" ] }, { "cell_type": "code", "execution_count": 4, - "id": "96636ad0-cde7-437d-a1ce-609287f6d005", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "19a4952086494afa9089e11a1bf2eace", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "VBox(children=(HBox(children=(DataGrid(auto_fit_params={'area': 'all', 'padding': 30, 'numCols': None}, base_r…" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], + "id": "d603672a-f43c-42f3-bf9a-de4499e27659", + "metadata": {}, + "outputs": [], "source": [ "container_widget.show()" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2e0f395a-17c3-45ae-9fe6-adef0dd4ac44", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { @@ -180,7 +149,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.3" + "version": "3.11.2" } }, "nbformat": 4, diff --git a/examples/mcorr.ipynb b/examples/mcorr.ipynb index 5aec15c..55f9230 100644 --- a/examples/mcorr.ipynb +++ b/examples/mcorr.ipynb @@ -20,13 +20,16 @@ "name": "stderr", "output_type": "stream", "text": [ - "2023-06-11 06:52:59.943298: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.\n", - "2023-06-11 06:52:59.965901: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.\n", - "2023-06-11 06:52:59.966418: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", + "2023-10-14 03:49:01.457161: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.\n", + "2023-10-14 03:49:01.489296: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", + "2023-10-14 03:49:01.489325: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", + "2023-10-14 03:49:01.489351: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", + "2023-10-14 03:49:01.495237: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.\n", + "2023-10-14 03:49:01.496002: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", - "2023-06-11 06:53:00.520465: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n", - "/home/kushal/repos/mesmerize-viz/mesmerize_viz/_mcorr.py:336: UserWarning: registration of accessor under name 'mcorr' for type is overriding a preexisting attribute with the same name.\n", - " @pd.api.extensions.register_series_accessor(\"mcorr\")\n" + "2023-10-14 03:49:02.279430: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n", + "Possible incompatible version of wgpu:\n", + " Detected 0.11.0, need >=0.10.0, <0.11.0.\n" ] } ], @@ -47,38 +50,22 @@ "set_parent_raw_data_path(\"/home/kushal/caiman_data/\")\n", "\n", "batch_path = \"/home/kushal/caiman_data/mesmerize-batch/batch.pickle\"\n", - "\n", "df = load_batch(batch_path)" ] }, { "cell_type": "code", "execution_count": 3, - "id": "c5991539-b585-45d4-869f-5495e21af8bd", - "metadata": { - "tags": [] - }, + "id": "562fb85c-aedc-4d7e-8810-7846fe74e47f", + "metadata": {}, "outputs": [], "source": [ - "from mesmerize_viz._mcorr import get_mcorr_data_mapping\n", - "from mesmerize_viz._common import ImageWidgetWrapper" + "from mesmerize_viz._utils import format_params" ] }, { "cell_type": "code", "execution_count": 4, - "id": "c341a6fc-f802-4c42-8ba8-a3055e1e5a3e", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "mapping = get_mcorr_data_mapping(df.iloc[0])" - ] - }, - { - "cell_type": "code", - "execution_count": 5, "id": "7b5d4c32-c5e7-42ed-8e39-55fcda337bcc", "metadata": { "tags": [] @@ -88,21 +75,19 @@ "name": "stderr", "output_type": "stream", "text": [ - "/home/kushal/venvs/mescore/lib/python3.11/site-packages/ipydatagrid/datagrid.py:445: UserWarning: Index name of 'index' is not round-trippable.\n", - " schema = pd.io.json.build_table_schema(dataframe)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Decode mmap filename /home/kushal/caiman_data/mesmerize-batch/b3503ab5-6376-4cd4-8c33-7f80735fca82/b3503ab5-6376-4cd4-8c33-7f80735fca82-Sue_2x_3000_40_-46_els__d1_170_d2_170_d3_1_order_F_frames_3000.mmap\n" + "/home/kushal/venvs/mescore/lib/python3.11/site-packages/ipydatagrid/datagrid.py:460: UserWarning: Index name of 'index' is not round-trippable.\n", + " schema = pd.io.json.build_table_schema(dataframe)\n", + "/home/kushal/repos/mesmerize-viz/mesmerize_viz/_mcorr.py:311: FutureWarning: You are trying to use the following experimental feature, this may change in the future without warning:\n", + "CaimanDataFrameExtensions.get_params_diffs\n", + "This feature is new and the might improve in the future\n", + "\n", + " param_diffs = self._dataframe.caiman.get_params_diffs(\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "b1ed1f20cc944cf495048afaaf503f7a", + "model_id": "b5fa0c34225f441581d8fd223e60a894", "version_major": 2, "version_minor": 0 }, @@ -112,49 +97,18 @@ }, "metadata": {}, "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/kushal/repos/fastplotlib/fastplotlib/graphics/features/_base.py:34: UserWarning: converting float64 array to float32\n", - " warn(f\"converting {array.dtype} array to float32\")\n", - "/home/kushal/repos/mesmerize-viz/mesmerize_viz/_mcorr.py:245: FutureWarning: You are trying to use the following experimental feature, this may change in the future without warning:\n", - "CaimanDataFrameExtensions.get_params_diffs\n", - "This feature is new and the might improve in the future\n", - "\n", - " param_diffs = self._dataframe.caiman.get_params_diffs(\n" - ] } ], "source": [ - "container_widget = df.mcorr.viz(data=[\"input\", \"mcorr\", \"mean\", \"corr\"])" + "container_widget = df.mcorr.viz(data_options=[\"input\", \"mcorr\", \"mean\", \"corr\"])" ] }, { "cell_type": "code", - "execution_count": 6, - "id": "de1c7796-1362-4808-a902-1b7f08d23ba8", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "4241832411b14e42a6217d7b79c4f8bc", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "VBox(children=(HBox(children=(DataGrid(auto_fit_params={'area': 'all', 'padding': 30, 'numCols': None}, base_r…" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": 5, + "id": "c7fac51e-5d91-4204-af25-4fce40ae20d5", + "metadata": {}, + "outputs": [], "source": [ "container_widget.show()" ] @@ -162,7 +116,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b6660b91-40e8-4b04-8bd3-89d774324d10", + "id": "50355864-53e1-44e4-898a-6499a061c0fd", "metadata": {}, "outputs": [], "source": [] @@ -184,7 +138,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.3" + "version": "3.11.2" } }, "nbformat": 4, From 13a2312e95556faf7e76d32a196008a7127ee9ae Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Sat, 14 Oct 2023 04:09:48 -0400 Subject: [PATCH 14/36] modify setup.py --- setup.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 8efa235..ab75695 100644 --- a/setup.py +++ b/setup.py @@ -4,9 +4,9 @@ install_requires = [ "mesmerize-core", - "fastplotlib", - "ipydatagrid" - + "fastplotlib[notebook]>=v0.1.0.a14", + "ipydatagrid", + "tslearn", ] From 2324c7d5de6277849cbde74481fe4fdc877a30d0 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Mon, 16 Oct 2023 20:12:16 -0400 Subject: [PATCH 15/36] sidecar optional for mcorr viz --- mesmerize_viz/_mcorr.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/mesmerize_viz/_mcorr.py b/mesmerize_viz/_mcorr.py index 30e294b..3d57c3b 100644 --- a/mesmerize_viz/_mcorr.py +++ b/mesmerize_viz/_mcorr.py @@ -321,7 +321,7 @@ def _set_params_text_area(self, index): # diffs and full params self.params_text_area.value = diffs + format_params(self._dataframe.iloc[index].params, 0) - def show(self): + def show(self, sidecar: bool = True): """ Show the widget """ @@ -331,12 +331,17 @@ def show(self): self.image_widget.reset_vmin_vmax() - with self.sidecar: - return display(VBox([ + widget = VBox([ HBox([self.datagrid, self.params_text_area]), self.image_widget.show(sidecar=False), self.slider_mean_window - ])) + ]) + + if not sidecar: + return widget + + with self.sidecar: + return display(widget) @pd.api.extensions.register_dataframe_accessor("mcorr") From ef4bf2deddce78373534c7089222fab51544427a Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 25 Oct 2023 22:36:57 -0400 Subject: [PATCH 16/36] cnmf works well --- mesmerize_viz/_cnmf/_extensions.py | 6 ++--- mesmerize_viz/_cnmf/_viz_container.py | 36 +++++++++++++++++++++------ mesmerize_viz/_cnmf/_wrapper.py | 30 +++++++++++----------- mesmerize_viz/_mcorr.py | 8 ++++-- 4 files changed, 53 insertions(+), 27 deletions(-) diff --git a/mesmerize_viz/_cnmf/_extensions.py b/mesmerize_viz/_cnmf/_extensions.py index bb87ebe..0cef1d0 100644 --- a/mesmerize_viz/_cnmf/_extensions.py +++ b/mesmerize_viz/_cnmf/_extensions.py @@ -13,7 +13,7 @@ def __init__(self, df): def viz( self, data_options: List[str] = None, - start_index: int = 0, + start_index: int = None, reset_timepoint_on_change: bool = False, data_graphic_kwargs: dict = None, gridplot_kwargs: dict = None, @@ -29,8 +29,8 @@ def viz( Parameters ---------- - data_options: list of str or list of list of str - default [["temporal"], ["heatmap-norm"], ["input", "rcm", "rcb", "residuals"]] + data_options: list of list of str + default [["temporal"], ["heatmap-zscore"], ["input", "rcm", "rcb", "residuals"]] **Note:** You may add suffixes to temporal and heatmap options for "dfof", "zscore", "norm", examples: "temporal-dfof", "heatmap-norm", "heatmap-zscore", "heatmap-dfof", etc. diff --git a/mesmerize_viz/_cnmf/_viz_container.py b/mesmerize_viz/_cnmf/_viz_container.py index 07c7a77..fed0c5e 100644 --- a/mesmerize_viz/_cnmf/_viz_container.py +++ b/mesmerize_viz/_cnmf/_viz_container.py @@ -118,7 +118,7 @@ def __init__( self, dataframe: pd.DataFrame, data: List[str] = None, - start_index: int = 0, + start_index: int = None, reset_timepoint_on_change: bool = False, data_graphic_kwargs: dict = None, gridplot_kwargs: dict = None, @@ -136,7 +136,7 @@ def __init__( ---------- dataframe: pd.DataFrame - data: list of str + data: list of str, or list of list of str data options, such as "input", "temporal", "contours", etc. start_index @@ -161,7 +161,11 @@ def __init__( """ if data is None: - data = [["temporal"], ["input", "rcm", "rcb", "residuals"]] + data = [["temporal"], ["heatmap-zscore"], ["input", "rcm", "rcb", "residuals"]] + # if it's the default options, it will hstack the temporal and heatmap next to the image data + self.default = True + else: + self.default = False if other_data_loaders is None: other_data_loaders = dict() @@ -246,6 +250,10 @@ def __init__( self.data_kwargs = data_kwargs + if start_index is None: + # try to guess the start index + start_index = dataframe[dataframe.algo == "cnmf"].iloc[0].name + self.current_row: int = start_index self._random_colors = None @@ -346,11 +354,25 @@ def show(self, sidecar: bool = True): """Show the widget""" # create gridplots and start render loop - gridplots_widget = [gp.show(sidecar=False) for gp in self.gridplots] + gridplots = [gp.show(sidecar=False) for gp in self.gridplots] + + # contour color controls and auto-zoom + contour_controls = VBox( + [ + HBox([self._gridplot_wrapper.checkbox_zoom_components, self._gridplot_wrapper.zoom_components_scale]), + self._box_contour_controls + ] + ) if "Jupyter" in self.gridplots[0].canvas.__class__.__name__: - gridplot_elements = gridplots_widget + if self.default: + # TODO: let's just make this the mandatory behavior, temporal + heatmap on left, any image stuff on right + # temporal and heatmap on left side, image data on right side + gridplot_elements = HBox([VBox(gridplots[:2]), VBox([gridplots[2], contour_controls])]) + else: + gridplot_elements = VBox(gridplots) else: + raise NotImplemented("show() not implemented outside of jupyter") gridplot_elements = list() if self.sidecar is None: @@ -360,9 +382,7 @@ def show(self, sidecar: bool = True): [ HBox([self.datagrid, self.params_text_area]), HBox([self._gridplot_wrapper.component_slider, self._gridplot_wrapper.component_int_box]), - VBox(gridplot_elements), - HBox([self._gridplot_wrapper.checkbox_zoom_components, self._gridplot_wrapper.zoom_components_scale]), - self._box_contour_controls + gridplot_elements, ] ) diff --git a/mesmerize_viz/_cnmf/_wrapper.py b/mesmerize_viz/_cnmf/_wrapper.py index 4ca7384..67b5f3d 100644 --- a/mesmerize_viz/_cnmf/_wrapper.py +++ b/mesmerize_viz/_cnmf/_wrapper.py @@ -74,7 +74,12 @@ def __init__( post_process_func: callable = None, ): """ - Basically like ``functools.partial`` but supports kwargs. + Basically a very fancy ``functools.partial``. + + In addition to behaving like ``functools.partial``, it supports: + - kwargs + - returning attributes of the return value from the callable + - postprocessing the return value Parameters ---------- @@ -204,7 +209,7 @@ def __init__( self.image_graphics: List[graphics.ImageGraphic] = list() self.contour_graphics: List[graphics.LineCollection] = list() - self.heatmap_selectors: List[LinearSelector] = list() + self.heatmap_component_ix_selectors: List[LinearSelector] = list() # selects heatmap rows, i.e. components self._managed_graphics: List[list] = [ self.temporal_graphics, @@ -216,7 +221,9 @@ def __init__( # to store only image data in a 1:1 mapping to the graphics list self.image_graphic_arrays: List[np.ndarray] = list() - self.linear_selectors: List[LinearSelector] = list() + self.linear_selectors: List[LinearSelector] = list() # select current timepoint, i.e. frame index + + self._synchronizer: Synchronizer = Synchronizer(key_bind=None) # synchronizes linear_selectors self._current_frame_index: int = 0 @@ -261,7 +268,7 @@ def set_component_index(self, index: int): for g in self.temporal_graphics: g.data = self._current_temporal_components[index] - for s in self.heatmap_selectors: + for s in self.heatmap_component_ix_selectors: # TODO: Very hacky for now, ignores if the slider is currently being moved, prevents weird slider movement if s._move_info is None: s.selection = index @@ -317,7 +324,8 @@ def change_data(self, data_mapping: Dict[str, callable]): for l in self._managed_graphics: l.clear() - self.heatmap_selectors.clear() + self._synchronizer.clear() # must clear synchronizer first before the selectors, else lingering weakrefs + self.heatmap_component_ix_selectors.clear() self.linear_selectors.clear() self.image_graphic_arrays.clear() @@ -352,14 +360,11 @@ def change_data(self, data_mapping: Dict[str, callable]): # connect events self._connect_events() - # sync sliders if multiple are present - if len(self.linear_selectors) > 0: - self._synchronizer = Synchronizer(*self.linear_selectors, key_bind=None) - for ls in self.linear_selectors: ls.selection.add_event_handler(self.set_frame_index) + self._synchronizer.add(ls) # sync linear_selectors - for hs in self.heatmap_selectors: + for hs in self.heatmap_component_ix_selectors: hs.selection.add_event_handler(self._heatmap_set_component_index) def _change_data_gridplot( @@ -450,7 +455,7 @@ def _change_data_gridplot( thickness=5, ) - self.heatmap_selectors.append(selector) + self.heatmap_component_ix_selectors.append(selector) else: # else it is an image @@ -517,9 +522,6 @@ def _connect_events(self): contour_graphic.link("colors", target=contour_graphic, feature="thickness", new_data=5) - # for temporal_graphic in self.temporal_graphics: - # contour_graphic.link("colors", target=temporal_graphic, feature="present", new_data=True) - for cg, tsg in product(self.contour_graphics, self.temporal_stack_graphics): cg.link("colors", target=contour_graphic, feature="colors", new_data="w", bidirectional=True) diff --git a/mesmerize_viz/_mcorr.py b/mesmerize_viz/_mcorr.py index 3d57c3b..3666bc9 100644 --- a/mesmerize_viz/_mcorr.py +++ b/mesmerize_viz/_mcorr.py @@ -63,7 +63,7 @@ def __init__( self, dataframe: pd.DataFrame, data_options: List[str] = None, - start_index: int = 0, + start_index: int = None, reset_timepoint_on_change: bool = False, input_movie_kwargs: dict = None, image_widget_kwargs: dict = None, @@ -89,7 +89,7 @@ def __init__( | pnr | peak-noise-ratio image, if computed | +-------------+-------------------------------------+ - start_index: int, default 0 + start_index: int start index item used to set the initial data in the ImageWidget reset_timepoint_on_change: bool, default False @@ -180,6 +180,10 @@ def __init__( self.reset_timepoint_on_change = reset_timepoint_on_change self.image_widget: ImageWidget = None + # try to guess the start index + if start_index is None: + start_index = dataframe[dataframe.algo == "mcorr"].iloc[0].name + self.current_row: int = start_index self._set_params_text_area(index=start_index) From 6633cd43fc4ba64e698ef18c5041f536aec6b8e8 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Sat, 28 Oct 2023 18:04:46 -0400 Subject: [PATCH 17/36] update w.r.t. latest fpl --- mesmerize_viz/_mcorr.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mesmerize_viz/_mcorr.py b/mesmerize_viz/_mcorr.py index 3666bc9..a15c4eb 100644 --- a/mesmerize_viz/_mcorr.py +++ b/mesmerize_viz/_mcorr.py @@ -165,7 +165,6 @@ def __init__( # default kwargs unless user has specified more default_iw_kwargs = { "window_funcs": {"t": (np.mean, 11)}, - "vmin_vmax_sliders": True, "cmap": "gnuplot2" } From c50fb8839cefcd4d8de3f9f9860cab98b256c477 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Mon, 30 Oct 2023 22:01:06 -0400 Subject: [PATCH 18/36] mcorr works with qt --- mesmerize_viz/_mcorr.py | 37 +++++++++++++++++++++++++------------ 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/mesmerize_viz/_mcorr.py b/mesmerize_viz/_mcorr.py index a15c4eb..9a54c68 100644 --- a/mesmerize_viz/_mcorr.py +++ b/mesmerize_viz/_mcorr.py @@ -58,6 +58,9 @@ def get_mcorr_data_mapping(series: pd.Series) -> dict: class McorrVizContainer: """Widget that contains the DataGrid, params text box and ImageWidget""" + @property + def widget(self): + return self._widget def __init__( self, @@ -223,6 +226,7 @@ def __init__( # self._checkbox_mean_diff self.sidecar = None + self._widget = None def _set_mean_window_size(self, change): self.image_widget.window_funcs = {"t": (np.mean, change["new"])} @@ -329,22 +333,31 @@ def show(self, sidecar: bool = True): Show the widget """ - if self.sidecar is None: - self.sidecar = Sidecar() - self.image_widget.reset_vmin_vmax() - widget = VBox([ - HBox([self.datagrid, self.params_text_area]), - self.image_widget.show(sidecar=False), - self.slider_mean_window - ]) + datagrid_params = HBox([self.datagrid, self.params_text_area]) + + if self.image_widget.gridplot.canvas.__class__.__name__ == "JupyterWgpuCanvas": + self._widget = VBox([ + datagrid_params, + self.image_widget.show(sidecar=False), + self.slider_mean_window + ]) + + if not sidecar: + return self.widget + + if self.sidecar is None: + self.sidecar = Sidecar() - if not sidecar: - return widget + with self.sidecar: + return display(self.widget) - with self.sidecar: - return display(widget) + elif self.image_widget.gridplot.canvas.__class__.__name__ == "QWgpuCanvas": + # shown the image widget in Qt window + self.image_widget.show() + # return datagrid to show in jupyter + return datagrid_params @pd.api.extensions.register_dataframe_accessor("mcorr") From 042ae83db37ba7bbfd13b7c3ed0a087ff6c79b6d Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Tue, 31 Oct 2023 03:23:40 -0400 Subject: [PATCH 19/36] star rewrite of cnmf using image widget --- mesmerize_viz/_cnmf.py | 474 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 474 insertions(+) create mode 100644 mesmerize_viz/_cnmf.py diff --git a/mesmerize_viz/_cnmf.py b/mesmerize_viz/_cnmf.py new file mode 100644 index 0000000..27f85a1 --- /dev/null +++ b/mesmerize_viz/_cnmf.py @@ -0,0 +1,474 @@ +from functools import partial +from typing import * +from warnings import warn + + +import numpy as np +import pandas as pd +from ipydatagrid import DataGrid +from ipywidgets import Textarea, Layout, HBox, VBox +from tslearn.preprocessing import TimeSeriesScalerMeanVariance, TimeSeriesScalerMinMax +import fastplotlib as fpl + + +from ._utils import DummyMovie + + +IMAGE_OPTIONS = [ + "input", + "rcm", + "rcb", + "residuals", + "corr", + "pnr", +] + +rcm_rcb_proj_options = list() +# RCM and RCB projections +for option in ["rcm", "rcb"]: + for proj in ["mean", "min", "max", "std"]: + rcm_rcb_proj_options.append(f"{option}-{proj}") + + +IMAGE_OPTIONS += rcm_rcb_proj_options + +PROJS = [ + "mean", + "max", + "std", +] + +IMAGE_OPTIONS += PROJS + + +class ExtensionCallWrapper: + def __init__( + self, + extension_func: callable, + kwargs: dict = None, + attr: str = None, + post_process_func: callable = None, + ): + """ + Basically a very fancy ``functools.partial``. + + In addition to behaving like ``functools.partial``, it supports: + - kwargs + - returning attributes of the return value from the callable + - postprocessing the return value + + Parameters + ---------- + extension_func: callable + extension function reference + + kwargs: dict + kwargs to pass to the extension function when it is called + + attr: str, optionalself, extension_func: callable, kwargs: dict = None, attr: str = None + return an attribute of the callable's output instead of the return value of the callable. + Example: if using rcm, can set ``attr="max_image"`` to return the max proj of the RCM. + + post_process_func: callable + A function to postprocess before returning, such as zscore, etc. + """ + + if kwargs is None: + self.kwargs = dict() + else: + self.kwargs = kwargs + + self.func = extension_func + self.attr = attr + self.post_process_func = post_process_func + + def __call__(self, *args, **kwargs): + rval = self.func(**self.kwargs) + + if self.attr is not None: + return getattr(rval, self.attr) + + if self.post_process_func is not None: + return self.post_process_func(rval) + + return rval + + +def get_cnmf_data_mapping( + series: pd.Series, + input_movie_kwargs: dict, + temporal_kwargs: dict, +): + projections = {k: partial(series.caiman.get_projection, k) for k in PROJS} + + rcm_rcb_projs = dict() + for proj in ["mean", "min", "max", "std"]: + rcm_rcb_projs[f"rcm-{proj}"] = ExtensionCallWrapper( + series.cnmf.get_rcm, + attr=f"{proj}_image" + ) + + zscore_func = TimeSeriesScalerMeanVariance().fit_transform + norm_func = TimeSeriesScalerMinMax().fit_transform + + temporal_mappings = { + "temporal": ExtensionCallWrapper(series.cnmf.get_temporal, temporal_kwargs), + "zscore": ExtensionCallWrapper(series.cnmf.get_temporal, temporal_kwargs, post_process_func=zscore_func), + "norm": ExtensionCallWrapper(series.cnmf.get_temporal, temporal_kwargs, post_process_func=norm_func), + "dfof": partial(series.cnmf.get_detrend_dfof), + "dfof-zscore": ExtensionCallWrapper(series.cnmf.get_detrend_dfof, post_process_func=zscore_func), + "dfof-norm": ExtensionCallWrapper(series.cnmf.get_detrend_dfof, post_process_func=zscore_func) + } + + mapping = { + "input": ExtensionCallWrapper(series.caiman.get_input_movie, input_movie_kwargs), + "rcm": series.cnmf.get_rcm, + "rcb": series.cnmf.get_rcb, + "residuals": series.cnmf.get_residuals, + "corr": series.caiman.get_corr_image, + "pnr": series.caiman.get_pnr_image, + "contours": ExtensionCallWrapper(series.cnmf.get_contours, {"swap_dim": False}), + **projections, + **rcm_rcb_projs, + **temporal_mappings, + } + + return mapping + + +class CNMFVizContainer: + """Widget that contains the DataGrid, params text box fastplotlib GridPlot, etc""" + + def __init__( + self, + dataframe: pd.DataFrame, + start_index: int = None, + temporal_data_option: str = None, + image_data_options: list[str] = None, + temporal_kwargs: dict = None, + reset_timepoint_on_change: bool = False, + input_movie_kwargs: dict = None, + image_widget_kwargs=None, + data_grid_kwargs: dict = None, + ): + """ + Visualize CNMF output and other data columns such as behavior video (optional). + + Note: If using dfof temporal_data_option, you must have already run dfof. + + Parameters + ---------- + dataframe: pd.DataFrame + + start_index: int + + temporal_data_option: optional, str + if not provided or ``None`: uses cnmf.get_temporal() + + if zscore: uses zscore of cnmf.get_temporal() + + if norm: uses 0-1 normalized output of cnmf.get_temporal() + + if dfof: uses cnmf.get_dfof() + + if dfof-zscore: uses cnmf.get_dfof() and then zscores + + if dfof-norm: uses cnmf.get_dfof() and then 0-1 normalizes + + reset_timepoint_on_change: bool + + temporal_postprocess: optional, list of str or callable + + heatmap_postprocess: str, None, callable + if str: one of "norm", "dfof", "zscore" + Or a callable to postprocess using your own function + + temporal_kwargs: dict + kwargs passed to cnmf.get_temporal(), example: {"add_residuals" : True}. + Ignored if temporal_data_option contains "dfof" + + input_movie_kwargs: dict + kwargs passed to caiman.get_input() + + data_grid_kwargs + """ + + self._dataframe = dataframe + + valid_temporal_options = [ + "temporal", + "zscore", + "norm", + "dfof", + "dfof-zscore", + "dfof-norm" + ] + + if temporal_data_option is None: + temporal_data_option = "temporal" + + if temporal_data_option not in valid_temporal_options: + raise ValueError( + f"You have passed the following invalid temporal option: {temporal_data_option}\n" + f"Valid options are:\n" + f"{valid_temporal_options}" + ) + + if image_data_options is None: + image_data_options = [ + "input", + "rcm", + "rcb", + "residuals" + ] + + for option in image_data_options: + if option not in IMAGE_OPTIONS: + raise ValueError( + f"Invalid image option passed, valid image options are:\n" + f"{IMAGE_OPTIONS}" + ) + + self.image_data_options = image_data_options + + self.temporal_data_option = temporal_data_option + self.temporal_kwargs = temporal_kwargs + + if self.temporal_kwargs is None: + self.temporal_kwargs = dict() + + # for now we will force all components, accepted and rejected, to be shown + if "component_indices" in self.temporal_kwargs.keys(): + raise ValueError( + "The kwarg `component_indices` is not allowed here." + ) + + self.reset_timepoint_on_change = reset_timepoint_on_change + self.input_movie_kwargs = input_movie_kwargs + + default_widths = { + "algo": 50, + 'item_name': 200, + 'input_movie_path': 120, + 'algo_duration': 80, + 'comments': 120, + 'uuid': 60 + } + + columns = dataframe.columns + # these add clutter + hide_columns = [ + "params", + "outputs", + "added_time", + "ran_time", + + ] + + df_show = self._dataframe[[c for c in columns if c not in hide_columns]] + + if data_grid_kwargs is None: + data_grid_kwargs = dict() + + self.datagrid = DataGrid( + df_show, # show only a subset + selection_mode="cell", + layout={"height": "250px", "width": "750px"}, + base_row_size=24, + index_name="index", + column_widths=default_widths, + **data_grid_kwargs + ) + + self.params_text_area = Textarea() + self.params_text_area.layout = Layout( + height="250px", + max_height="250px", + width="360px", + max_width="500px", + disabled=True, + ) + + if image_widget_kwargs is None: + image_widget_kwargs = dict() + + default_image_widget_kwargs = { + "cmap": "gnuplot2" + } + + self.image_widget_kwargs = { + **default_image_widget_kwargs, + **image_widget_kwargs + } + + if start_index is None: + start_index = dataframe[dataframe.algo == "cnmf"].iloc[0].name + + self.current_row: int = start_index + + self.datagrid.select( + row1=start_index, + column1=0, + row2=start_index, + column2=len(df_show.columns), + clear_mode="all" + ) + + # callback when row changed + self.datagrid.observe(self._row_changed, names="selections") + + self._plot_temporal = fpl.Plot() + self._plot_temporal.camera.maintain_aspect = False + self._plot_heatmap = fpl.Plot() + self._plot_heatmap.camera.maintain_aspect = False + + self._image_widget: fpl.ImageWidget = None + + self._synchronizer = fpl.Synchronizer(key_bind=None) + + data_arrays = self._get_row_data(index=start_index) + self._set_data(data_arrays) + + def _get_selected_row(self) -> Union[int, None]: + r1 = self.datagrid.selections[0]["r1"] + r2 = self.datagrid.selections[0]["r2"] + + if r1 != r2: + warn("Only single row selection is currently allowed") + return + + # get corresponding dataframe index from currently visible dataframe + # since filtering etc. is possible + index = self.datagrid.get_visible_data().index[r1] + + return index + + def _get_row_data(self, index: int) -> Dict[str, np.ndarray]: + data_mapping = get_cnmf_data_mapping( + series=self._dataframe.iloc[index], + input_movie_kwargs=self.input_movie_kwargs, + temporal_kwargs=self.temporal_kwargs + ) + + temporal = data_mapping[self.temporal_data_option]() + + rcm = data_mapping["rcm"]() + + shape = rcm.shape + ndim = rcm.ndim + size = rcm.shape[0] * rcm.shape[1] * rcm.shape[2] + + images = list() + for option in self.image_data_options: + array = data_mapping[option]() + + if array.ndim == 2: # for 2D images, to make ImageWidget happy + array = DummyMovie(array, shape=shape, ndim=ndim, size=size) + + images.append(array) + + data_arrays = { + "temporal": temporal, + "images": images, + } + + return data_arrays + + def _row_changed(self, *args): + index = self._get_selected_row() + if index is None: + return + + if self.current_row == index: + return + + try: + data_arrays = self._get_row_data(index) + + except Exception as e: + self.params_text_area.value = f"{type(e).__name__}\n" \ + f"{str(e)}\n\n" \ + f"See jupyter log for details" + raise e + + else: + # no exceptions, set plots + self._set_data(data_arrays) + + def _set_data(self, data_arrays: Dict[str, np.ndarray]): + # self._contour_graphics.clear() + + self._plot_temporal.clear() + self._plot_heatmap.clear() + + self._synchronizer.clear() + + if self._image_widget is None: + self._image_widget = fpl.ImageWidget( + data=data_arrays["images"], + **self.image_widget_kwargs + ) + + # need to start it here so that we can access the toolbar to link events with the slider + self._image_widget.show() + + else: + self._image_widget.set_data(data_arrays["images"]) + + temporal = data_arrays["temporal"] + + self._plot_temporal.add_line(temporal[0], name="line") + self._plot_heatmap.add_heatmap(temporal, name="heatmap") + + self._linear_selector_temporal: fpl.LinearSelector = self._plot_temporal["line"].add_linear_selector() + self._linear_selector_temporal.selection.add_event_handler(self._set_frame_index_from_linear_selector) + + self._linear_selector_heatmap: fpl.LinearSelector = self._plot_heatmap["heatmap"].add_linear_selector() + + # sync the linear selectors + self._synchronizer.add(self._linear_selector_temporal) + self._synchronizer.add(self._linear_selector_heatmap) + + # absolute garbage monkey patch which I will fix once we make ImageWidget emit its own events + if hasattr(self._image_widget.sliders["t"], "qslider"): + self._image_widget.sliders["t"].qslider.valueChanged.connect(self._set_linear_selector_index_from_image_widget) + else: + # ipywidget + self._image_widget.sliders["t"].observe(self._set_linear_selector_index_from_image_widget, "value") + + def _set_frame_index_from_linear_selector(self, ev): + # TODO: hacky mess, need to make ImageWidget emit events + ix = ev.pick_info["selected_index"] + self._image_widget.sliders["t"].value = ix + + def _set_linear_selector_index_from_image_widget(self, ev): + if isinstance(ev, dict): + # ipywidget + ix = ev["new"] + + # else it's directly from Qt slider + else: + ix = ev + + self._linear_selector_temporal.selection = ix + + def show(self, sidecar: bool = False): + """ + Show the widget + + Parameters + ---------- + sidecar + + Returns + ------- + + """ + + datagrid_params = HBox([self.datagrid, self.params_text_area]) + + temporals = VBox([self._plot_temporal.show(), self._plot_heatmap.show()]) + + plots = HBox([temporals, self._image_widget.widget]) + + return VBox([datagrid_params, plots]) From 768f4a54ff4a9e6fecaed1b251944af904921ebb Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Tue, 31 Oct 2023 03:44:44 -0400 Subject: [PATCH 20/36] add click contour events --- mesmerize_viz/_cnmf.py | 76 ++++++++++++++++++++++++++++++++++++++---- 1 file changed, 70 insertions(+), 6 deletions(-) diff --git a/mesmerize_viz/_cnmf.py b/mesmerize_viz/_cnmf.py index 27f85a1..b60b2c3 100644 --- a/mesmerize_viz/_cnmf.py +++ b/mesmerize_viz/_cnmf.py @@ -229,7 +229,7 @@ def __init__( f"{IMAGE_OPTIONS}" ) - self.image_data_options = image_data_options + self._image_data_options = image_data_options self.temporal_data_option = temporal_data_option self.temporal_kwargs = temporal_kwargs @@ -326,6 +326,8 @@ def __init__( self._synchronizer = fpl.Synchronizer(key_bind=None) + self._contour_graphics: List[fpl.LineCollection] = list() + data_arrays = self._get_row_data(index=start_index) self._set_data(data_arrays) @@ -359,7 +361,7 @@ def _get_row_data(self, index: int) -> Dict[str, np.ndarray]: size = rcm.shape[0] * rcm.shape[1] * rcm.shape[2] images = list() - for option in self.image_data_options: + for option in self._image_data_options: array = data_mapping[option]() if array.ndim == 2: # for 2D images, to make ImageWidget happy @@ -367,9 +369,12 @@ def _get_row_data(self, index: int) -> Dict[str, np.ndarray]: images.append(array) + contours = data_mapping["contours"]() + data_arrays = { "temporal": temporal, "images": images, + "contours": contours, } return data_arrays @@ -396,7 +401,7 @@ def _row_changed(self, *args): self._set_data(data_arrays) def _set_data(self, data_arrays: Dict[str, np.ndarray]): - # self._contour_graphics.clear() + self._contour_graphics.clear() self._plot_temporal.clear() self._plot_heatmap.clear() @@ -406,6 +411,7 @@ def _set_data(self, data_arrays: Dict[str, np.ndarray]): if self._image_widget is None: self._image_widget = fpl.ImageWidget( data=data_arrays["images"], + names=self._image_data_options, **self.image_widget_kwargs ) @@ -415,11 +421,13 @@ def _set_data(self, data_arrays: Dict[str, np.ndarray]): else: self._image_widget.set_data(data_arrays["images"]) - temporal = data_arrays["temporal"] + self._temporal_data = data_arrays["temporal"] - self._plot_temporal.add_line(temporal[0], name="line") - self._plot_heatmap.add_heatmap(temporal, name="heatmap") + # make temporal graphics + self._plot_temporal.add_line(self._temporal_data[0], name="line") + self._plot_heatmap.add_heatmap(self._temporal_data, name="heatmap") + # linear selectors and events self._linear_selector_temporal: fpl.LinearSelector = self._plot_temporal["line"].add_linear_selector() self._linear_selector_temporal.selection.add_event_handler(self._set_frame_index_from_linear_selector) @@ -436,6 +444,61 @@ def _set_data(self, data_arrays: Dict[str, np.ndarray]): # ipywidget self._image_widget.sliders["t"].observe(self._set_linear_selector_index_from_image_widget, "value") + contours = data_arrays["contours"][0] + + n_components = len(contours) + component_colors = np.random.rand(n_components, 4).astype(np.float32) + component_colors[:, -1] = 1 + + for subplot in self._image_widget.gridplot: + contour_graphic = subplot.add_line_collection( + contours, + colors=component_colors, + name="contours" + ) + self._contour_graphics.append(contour_graphic) + + image_graphic = subplot["image_widget_managed"] + + image_graphic.link( + "click", + target=contour_graphic, + feature="colors", + new_data="w", + callback=self._euclidean + ) + + contour_graphic.link("colors", target=contour_graphic, feature="thickness", new_data=2) + + def _euclidean(self, source, target, event, new_data): + """maps click events to contour""" + # calculate coms of line collection + indices = np.array(event.pick_info["index"]) + + coms = list() + + for contour in target.graphics: + coors = contour.data()[~np.isnan(contour.data()).any(axis=1)] + com = coors.mean(axis=0) + coms.append(com) + + # euclidean distance to find closest index of com + indices = np.append(indices, [0]) + + ix = int(np.linalg.norm((coms - indices), axis=1).argsort()[0]) + + self.set_component_index(ix) + # + # self.component_int_box.value = ix + + return None + + def set_component_index(self, ix): + for g in self._contour_graphics: + g.set_feature(feature="colors", new_data="w", indices=ix) + + self._plot_temporal["line"].data = self._temporal_data[ix] + def _set_frame_index_from_linear_selector(self, ev): # TODO: hacky mess, need to make ImageWidget emit events ix = ev.pick_info["selected_index"] @@ -451,6 +514,7 @@ def _set_linear_selector_index_from_image_widget(self, ev): ix = ev self._linear_selector_temporal.selection = ix + self._linear_selector_heatmap.selection = ix def show(self, sidecar: bool = False): """ From 0a2713afed9e24c20daf2f7c73bee4fe0bd0d259 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 1 Nov 2023 01:53:06 -0400 Subject: [PATCH 21/36] component selection --- mesmerize_viz/_cnmf.py | 85 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 74 insertions(+), 11 deletions(-) diff --git a/mesmerize_viz/_cnmf.py b/mesmerize_viz/_cnmf.py index b60b2c3..a2554cf 100644 --- a/mesmerize_viz/_cnmf.py +++ b/mesmerize_viz/_cnmf.py @@ -6,7 +6,7 @@ import numpy as np import pandas as pd from ipydatagrid import DataGrid -from ipywidgets import Textarea, Layout, HBox, VBox +from ipywidgets import Textarea, Layout, HBox, VBox, Checkbox, FloatSlider, IntSlider, BoundedIntText, RadioButtons, Dropdown, jslink from tslearn.preprocessing import TimeSeriesScalerMeanVariance, TimeSeriesScalerMinMax import fastplotlib as fpl @@ -317,6 +317,36 @@ def __init__( # callback when row changed self.datagrid.observe(self._row_changed, names="selections") + self.component_slider = IntSlider(min=0, max=1, value=0, step=1, description="component index:") + self.component_int_box = BoundedIntText(min=0, max=1, value=0, step=1) + for trait in ["value", "max"]: + jslink((self.component_slider, trait), (self.component_int_box, trait)) + + self.component_int_box.observe( + lambda change: self.set_component_index(change["new"]), "value" + ) + + self.checkbox_zoom_components = Checkbox( + value=True, + description="auto-zoom component", + description_tooltip="If checked, zoom into selected component" + ) + + self.zoom_components_scale = FloatSlider( + min=0.25, + max=3, + value=1, + step=0.25, + description="zoom scale", + description_tooltip="zoom scale as a factor of component width/height" + ) + + self._top_widget = VBox([ + HBox([self.datagrid, self.params_text_area]), + HBox([self.component_slider, self.component_int_box]), + HBox([self.checkbox_zoom_components, self.zoom_components_scale]) + ]) + self._plot_temporal = fpl.Plot() self._plot_temporal.camera.maintain_aspect = False self._plot_heatmap = fpl.Plot() @@ -403,11 +433,11 @@ def _row_changed(self, *args): def _set_data(self, data_arrays: Dict[str, np.ndarray]): self._contour_graphics.clear() + self._synchronizer.clear() + self._plot_temporal.clear() self._plot_heatmap.clear() - self._synchronizer.clear() - if self._image_widget is None: self._image_widget = fpl.ImageWidget( data=data_arrays["images"], @@ -419,14 +449,24 @@ def _set_data(self, data_arrays: Dict[str, np.ndarray]): self._image_widget.show() else: + # image widget doesn't need clear, we can just use set_data self._image_widget.set_data(data_arrays["images"]) + for subplot in self._image_widget.gridplot: + if "contours" in subplot: + # delete the contour graphics + subplot.delete_graphic(subplot["contours"]) self._temporal_data = data_arrays["temporal"] # make temporal graphics self._plot_temporal.add_line(self._temporal_data[0], name="line") + # autoscale the single temporal line plot when the data changes + self._plot_temporal["line"].data.add_event_handler(self._plot_temporal.auto_scale) self._plot_heatmap.add_heatmap(self._temporal_data, name="heatmap") + self._component_linear_selector: fpl.LinearSelector = self._plot_heatmap["heatmap"].add_linear_selector(axis="y", thickness=5) + self._component_linear_selector.selection.add_event_handler(self.set_component_index) + # linear selectors and events self._linear_selector_temporal: fpl.LinearSelector = self._plot_temporal["line"].add_linear_selector() self._linear_selector_temporal.selection.add_event_handler(self._set_frame_index_from_linear_selector) @@ -488,16 +528,41 @@ def _euclidean(self, source, target, event, new_data): ix = int(np.linalg.norm((coms - indices), axis=1).argsort()[0]) self.set_component_index(ix) - # - # self.component_int_box.value = ix + + self.component_int_box.value = ix return None - def set_component_index(self, ix): + def set_component_index(self, index): + if hasattr(index, "pick_info"): + # came from heatmap component selector + if index.pick_info["pygfx_event"] is None: + # this means that the selector was not triggered by the user but that it moved due to another event + # so then we don't set_component_index because then infinite recursion + return + index = index.pick_info["selected_index"] + for g in self._contour_graphics: - g.set_feature(feature="colors", new_data="w", indices=ix) + g.set_feature(feature="colors", new_data="w", indices=index) - self._plot_temporal["line"].data = self._temporal_data[ix] + self._plot_temporal["line"].data = self._temporal_data[index] + + if self._component_linear_selector._move_info is None: + # TODO: Very hacky for now, ignores if the slider is currently being moved by the user + # prevents weird slider movement + self._component_linear_selector.selection = index + + self._zoom_into_component(index) + + def _zoom_into_component(self, index: int): + if not self.checkbox_zoom_components.value: + return + + for subplot in self._image_widget.gridplot: + subplot.camera.show_object( + subplot["contours"].graphics[index].world_object, + scale=self.zoom_components_scale.value + ) def _set_frame_index_from_linear_selector(self, ev): # TODO: hacky mess, need to make ImageWidget emit events @@ -529,10 +594,8 @@ def show(self, sidecar: bool = False): """ - datagrid_params = HBox([self.datagrid, self.params_text_area]) - temporals = VBox([self._plot_temporal.show(), self._plot_heatmap.show()]) plots = HBox([temporals, self._image_widget.widget]) - return VBox([datagrid_params, plots]) + return VBox([self._top_widget, plots]) From 586c1f5dec47017927bd7ec1323435312e5a2de9 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 1 Nov 2023 02:17:39 -0400 Subject: [PATCH 22/36] contour colors --- mesmerize_viz/_cnmf.py | 152 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 146 insertions(+), 6 deletions(-) diff --git a/mesmerize_viz/_cnmf.py b/mesmerize_viz/_cnmf.py index a2554cf..4caae78 100644 --- a/mesmerize_viz/_cnmf.py +++ b/mesmerize_viz/_cnmf.py @@ -317,6 +317,7 @@ def __init__( # callback when row changed self.datagrid.observe(self._row_changed, names="selections") + # ipywidgets for selecting components self.component_slider = IntSlider(min=0, max=1, value=0, step=1, description="component index:") self.component_int_box = BoundedIntText(min=0, max=1, value=0, step=1) for trait in ["value", "max"]: @@ -326,12 +327,13 @@ def __init__( lambda change: self.set_component_index(change["new"]), "value" ) + # checkbox to zoom into components when selected self.checkbox_zoom_components = Checkbox( value=True, description="auto-zoom component", description_tooltip="If checked, zoom into selected component" ) - + # zoom factor self.zoom_components_scale = FloatSlider( min=0.25, max=3, @@ -340,13 +342,47 @@ def __init__( description="zoom scale", description_tooltip="zoom scale as a factor of component width/height" ) - + # organize these widgets to be shown at the top self._top_widget = VBox([ HBox([self.datagrid, self.params_text_area]), HBox([self.component_slider, self.component_int_box]), HBox([self.checkbox_zoom_components, self.zoom_components_scale]) ]) + self._dropdown_contour_colors = Dropdown( + options=["random", "accepted", "rejected", "snr_comps", "snr_comps_log", "r_values", "cnn_preds"], + value="random", + description='contour colors:', + ) + + self._dropdown_contour_colors.observe(self._ipywidget_set_component_colors, "value") + + self._radio_visible_components = RadioButtons( + options=["all", "accepted", "rejected"], + description_tooltip="contours to make visible", + description="visible contours" + ) + + self._radio_visible_components.observe(self._ipywidget_set_component_colors, "value") + + self._spinbox_alpha_invisible_contours = FloatSlider( + value=0.0, + min=0.0, + max=1.0, + step=0.1, + description="invisible alpha:", + description_tooltip="transparency of contours set to be invisible", + disabled=False + ) + + self._spinbox_alpha_invisible_contours.observe(self._ipywidget_set_component_colors, "value") + + self._box_contour_controls = VBox([ + self._dropdown_contour_colors, + HBox([self._radio_visible_components, self._spinbox_alpha_invisible_contours]) + ]) + + # plots self._plot_temporal = fpl.Plot() self._plot_temporal.camera.maintain_aspect = False self._plot_heatmap = fpl.Plot() @@ -487,13 +523,13 @@ def _set_data(self, data_arrays: Dict[str, np.ndarray]): contours = data_arrays["contours"][0] n_components = len(contours) - component_colors = np.random.rand(n_components, 4).astype(np.float32) - component_colors[:, -1] = 1 + self._random_colors = np.random.rand(n_components, 4).astype(np.float32) + self._random_colors[:, -1] = 1 for subplot in self._image_widget.gridplot: contour_graphic = subplot.add_line_collection( contours, - colors=component_colors, + colors=self._random_colors, name="contours" ) self._contour_graphics.append(contour_graphic) @@ -581,6 +617,108 @@ def _set_linear_selector_index_from_image_widget(self, ev): self._linear_selector_temporal.selection = ix self._linear_selector_heatmap.selection = ix + def _ipywidget_set_component_colors(self, *args): + """just a wrapper to make ipywidgets happy""" + colors = self._dropdown_contour_colors.value + self.set_component_colors(colors) + + def set_component_colors( + self, + metric: Union[str, np.ndarray], + cmap: str = None, + ): + """ + + Parameters + ---------- + metric: str or np.ndarray + str, one of: random, accepted, rejected, accepted-rejected, snr_comps, snr_comps_log, + r_values, cnn_preds. + + Can also pass a 1D array of other metrics + + If np.ndarray, it must be of the same length as the number of components + + cmap: str + custom cmap for the colors + + Returns + ------- + + """ + cnmf_obj = self._dataframe.iloc[self.current_row].cnmf.get_output() + n_contours = len(self._image_widget.gridplot[0, 0]["contours"]) + + # use the random colors + if metric == "random": + for subplot in self._image_widget.gridplot: + for i, g in enumerate(subplot["contours"].graphics): + g.colors = self._random_colors[i] + + # set alpha values based on all, accepted, rejected selection + self._set_component_visibility(subplot["contours"], cnmf_obj) + return + + if metric in ["accepted", "rejected"]: + if cmap is None: + cmap = "Set1" + + # make a empty array for cmap_values + classifier = np.zeros(n_contours, dtype=int) + # set the accepted components to 1 + classifier[cnmf_obj.estimates.idx_components] = 1 + + else: + if cmap is None: + cmap = "spring" + + if metric == "snr_comps": + classifier = cnmf_obj.estimates.SNR_comp + + elif metric == "snr_comps_log": + classifier = np.log10(cnmf_obj.estimates.SNR_comp) + + elif metric == "r_values": + classifier = cnmf_obj.estimates.r_values + + elif metric == "cnn_preds": + classifier = cnmf_obj.estimates.cnn_preds + + elif isinstance(metric, np.ndarray): + if not metric.size == n_contours: + raise ValueError(f"If using np.ndarray cor component_colors, the array size must be " + f"the same as n_contours: {n_contours}, your array size is: {metric.size}") + + classifier = metric + + else: + raise ValueError("Invalid colors value") + + for subplot in self._image_widget.gridplot: + # first initialize using a quantitative cmap + # this ensures that setting cmap_values will work + subplot["contours"].cmap = "gray" + + subplot["contours"].cmap_values = classifier + subplot["contours"].cmap = cmap + + self._set_component_visibility(subplot["contours"], cnmf_obj) + + def _set_component_visibility(self, contours: fpl.LineCollection, cnmf_obj): + visible = self._radio_visible_components.value + alpha_invisible = self._spinbox_alpha_invisible_contours.value + + # choose to make all or accepted or rejected visible + if visible == "accepted": + contours[cnmf_obj.estimates.idx_components_bad].colors[:, -1] = alpha_invisible + + elif visible == "rejected": + contours[cnmf_obj.estimates.idx_components].colors[:, -1] = alpha_invisible + + else: + # make everything visible + contours[:].colors[:, -1] = 1 + def show(self, sidecar: bool = False): """ Show the widget @@ -596,6 +734,8 @@ def show(self, sidecar: bool = False): temporals = VBox([self._plot_temporal.show(), self._plot_heatmap.show()]) - plots = HBox([temporals, self._image_widget.widget]) + iw_contour_controls = VBox([self._image_widget.widget, self._box_contour_controls]) + + plots = HBox([temporals, iw_contour_controls]) return VBox([self._top_widget, plots]) From 012c8255b1ddc143d890da461e2c32aef96d8aa7 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 1 Nov 2023 02:47:17 -0400 Subject: [PATCH 23/36] component metrics text box, layouting --- mesmerize_viz/_cnmf.py | 40 ++++++++++++++++++++++++++++++++++------ 1 file changed, 34 insertions(+), 6 deletions(-) diff --git a/mesmerize_viz/_cnmf.py b/mesmerize_viz/_cnmf.py index 4caae78..b886911 100644 --- a/mesmerize_viz/_cnmf.py +++ b/mesmerize_viz/_cnmf.py @@ -6,9 +6,10 @@ import numpy as np import pandas as pd from ipydatagrid import DataGrid -from ipywidgets import Textarea, Layout, HBox, VBox, Checkbox, FloatSlider, IntSlider, BoundedIntText, RadioButtons, Dropdown, jslink +from ipywidgets import Text, Textarea, Layout, HBox, VBox, Checkbox, FloatSlider, IntSlider, BoundedIntText, RadioButtons, Dropdown, jslink from tslearn.preprocessing import TimeSeriesScalerMeanVariance, TimeSeriesScalerMinMax import fastplotlib as fpl +from caiman.source_extraction.cnmf import CNMF from ._utils import DummyMovie @@ -121,6 +122,7 @@ def get_cnmf_data_mapping( } mapping = { + "cnmf_obj": series.cnmf.get_output, "input": ExtensionCallWrapper(series.caiman.get_input_movie, input_movie_kwargs), "rcm": series.cnmf.get_rcm, "rcb": series.cnmf.get_rcb, @@ -293,7 +295,8 @@ def __init__( image_widget_kwargs = dict() default_image_widget_kwargs = { - "cmap": "gnuplot2" + "cmap": "gnuplot2", + "grid_plot_kwargs": {"size": (720, 602)}, } self.image_widget_kwargs = { @@ -319,7 +322,7 @@ def __init__( # ipywidgets for selecting components self.component_slider = IntSlider(min=0, max=1, value=0, step=1, description="component index:") - self.component_int_box = BoundedIntText(min=0, max=1, value=0, step=1) + self.component_int_box = BoundedIntText(min=0, max=1, value=0, step=1, layout=Layout(width="100px")) for trait in ["value", "max"]: jslink((self.component_slider, trait), (self.component_int_box, trait)) @@ -327,6 +330,14 @@ def __init__( lambda change: self.set_component_index(change["new"]), "value" ) + self._component_metrics_text = Text( + value="", + placeholder="component metrics", + description='metrics:', + disabled=True, + layout=Layout(width="350px") + ) + # checkbox to zoom into components when selected self.checkbox_zoom_components = Checkbox( value=True, @@ -345,7 +356,7 @@ def __init__( # organize these widgets to be shown at the top self._top_widget = VBox([ HBox([self.datagrid, self.params_text_area]), - HBox([self.component_slider, self.component_int_box]), + HBox([self.component_slider, self.component_int_box, self._component_metrics_text]), HBox([self.checkbox_zoom_components, self.zoom_components_scale]) ]) @@ -383,9 +394,9 @@ def __init__( ]) # plots - self._plot_temporal = fpl.Plot() + self._plot_temporal = fpl.Plot(size=(500, 120)) self._plot_temporal.camera.maintain_aspect = False - self._plot_heatmap = fpl.Plot() + self._plot_heatmap = fpl.Plot(size=(500, 450)) self._plot_heatmap.camera.maintain_aspect = False self._image_widget: fpl.ImageWidget = None @@ -436,11 +447,13 @@ def _get_row_data(self, index: int) -> Dict[str, np.ndarray]: images.append(array) contours = data_mapping["contours"]() + cnmf_obj = data_mapping["cnmf_obj"]() data_arrays = { "temporal": temporal, "images": images, "contours": contours, + "cnmf_obj": cnmf_obj } return data_arrays @@ -546,6 +559,13 @@ def _set_data(self, data_arrays: Dict[str, np.ndarray]): contour_graphic.link("colors", target=contour_graphic, feature="thickness", new_data=2) + self.component_int_box.value = 0 + self.component_slider.value = 0 + self.component_int_box.max = n_components - 1 + self.component_slider.max = n_components - 1 + + self._cnmf_obj: CNMF = data_arrays["cnmf_obj"] + def _euclidean(self, source, target, event, new_data): """maps click events to contour""" # calculate coms of line collection @@ -590,6 +610,14 @@ def set_component_index(self, index): self._zoom_into_component(index) + self.component_int_box.value = index + + metrics = (f"snr: {self._cnmf_obj.estimates.SNR_comp[index]:.02f}, " + f"r_values: {self._cnmf_obj.estimates.r_values[index]:.02f}, " + f"cnn: {self._cnmf_obj.estimates.cnn_preds[index]:.02f} ") + + self._component_metrics_text.value = metrics + def _zoom_into_component(self, index: int): if not self.checkbox_zoom_components.value: return From 6aed42aeec8d5049acb03ce0dac636e1d81d92f7 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 1 Nov 2023 02:54:16 -0400 Subject: [PATCH 24/36] fix callback --- mesmerize_viz/_cnmf.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mesmerize_viz/_cnmf.py b/mesmerize_viz/_cnmf.py index b886911..99a51df 100644 --- a/mesmerize_viz/_cnmf.py +++ b/mesmerize_viz/_cnmf.py @@ -610,7 +610,11 @@ def set_component_index(self, index): self._zoom_into_component(index) + self.component_int_box.unobserve_all() self.component_int_box.value = index + self.component_int_box.observe( + lambda change: self.set_component_index(change["new"]), "value" + ) metrics = (f"snr: {self._cnmf_obj.estimates.SNR_comp[index]:.02f}, " f"r_values: {self._cnmf_obj.estimates.r_values[index]:.02f}, " From 1f2661492dd66faa355fb1422a083155fdcb6131 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 1 Nov 2023 04:19:31 -0400 Subject: [PATCH 25/36] eval seems to work --- mesmerize_viz/_cnmf.py | 141 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 138 insertions(+), 3 deletions(-) diff --git a/mesmerize_viz/_cnmf.py b/mesmerize_viz/_cnmf.py index 99a51df..e6b7293 100644 --- a/mesmerize_viz/_cnmf.py +++ b/mesmerize_viz/_cnmf.py @@ -1,3 +1,4 @@ +from collections import OrderedDict from functools import partial from typing import * from warnings import warn @@ -6,7 +7,7 @@ import numpy as np import pandas as pd from ipydatagrid import DataGrid -from ipywidgets import Text, Textarea, Layout, HBox, VBox, Checkbox, FloatSlider, IntSlider, BoundedIntText, RadioButtons, Dropdown, jslink +from ipywidgets import Button, Tab, Text, Textarea, Layout, HBox, VBox, Checkbox, FloatSlider, BoundedFloatText, IntSlider, BoundedIntText, RadioButtons, Dropdown, jslink from tslearn.preprocessing import TimeSeriesScalerMeanVariance, TimeSeriesScalerMinMax import fastplotlib as fpl from caiman.source_extraction.cnmf import CNMF @@ -138,6 +139,107 @@ def get_cnmf_data_mapping( return mapping +class EvalController: + def __init__(self): + self._float_metrics = [ + "min_SNR", + "SNR_lowest", + "rval_thr", + "rval_lowest", + "min_cnn_thr", + "cnn_lowest", + + ] + + # caiman is really annoying with this + # maps eval metric to the estimates attrs + self._metric_array_mapping = { + "min_SNR": "SNR_comp", + "SNR_lowest": "SNR_comp", + "rval_thr": "r_values", + "rval_lowest": "r_values", + "min_cnn_thr": "cnn_preds", + "cnn_lowest": "cnn_preds", + } + + self._widgets = OrderedDict() + + param_entries = list() + + for metric in self._float_metrics: + slider = FloatSlider(value=0, min=0, max=1, step=0.01, description=metric) + spinbox = BoundedFloatText( + value=0, min=0, max=1, step=0.01, description_tooltip=metric, layout=Layout(width="70px") + ) + + slider.observe(self._call_handlers, "value") + spinbox.observe(self._call_handlers, "value") + + jslink((slider, "value"), (spinbox, "value")) + + param_entries.append(HBox([spinbox, slider])) + + # keep this so it's easier to modify in set_limits + self._widgets[metric] = {"slider": slider, "spinbox": spinbox} + + self.use_cnn_checkbox = Checkbox( + value=True, + description="use_cnn", + description_tooltip="use CNN classifier" + ) + + self.widget = VBox([*param_entries, self.use_cnn_checkbox]) + + self._handlers = list() + + # limits must be set first before it's usable + self._block_handlers = True + + self.button_save_eval = Button(description="Save Eval") + + def set_limits(self, cnmf_obj: CNMF): + self._block_handlers = True + for metric in self._float_metrics: + metric_array = getattr(cnmf_obj.estimates, self._metric_array_mapping[metric]) + for kind in ["slider", "spinbox"]: + # allow 100 steps + self._widgets[metric][kind].step = np.ptp(metric_array) / 100 + self._widgets[metric][kind].min = metric_array.min() + self._widgets[metric][kind].max = metric_array.max() + self._widgets[metric][kind].value = cnmf_obj.params.get_group("quality")[metric] + + self.use_cnn_checkbox.value = cnmf_obj.params.get_group("quality")["use_cnn"] + + self._block_handlers = False + + def get_data(self): + data = dict() + for metric in self._float_metrics: + data[metric] = self._widgets[metric]["spinbox"].value + + data["use_cnn"] = self.use_cnn_checkbox.value + + return data + + def add_handler(self, func: callable): + """Handlers must accept a dict argument, the dict has the eval params""" + self._handlers.append(func) + + def _call_handlers(self, obj): + if self._block_handlers: + return + + data = self.get_data() + for handler in self._handlers: + handler(data) + + def remove_handler(self, func: callable): + self._handlers.remove(func) + + def clear_handlers(self): + self._handlers.clear() + + class CNMFVizContainer: """Widget that contains the DataGrid, params text box fastplotlib GridPlot, etc""" @@ -393,6 +495,13 @@ def __init__( HBox([self._radio_visible_components, self._spinbox_alpha_invisible_contours]) ]) + self._eval_controller = EvalController() + self._eval_controller.add_handler(self._set_eval) + self._eval_controller.button_save_eval.on_click(self._save_eval) + + self._tab_contours_eval = Tab() + self._tab_contours_eval.children = [self._box_contour_controls, self._eval_controller.widget] + # plots self._plot_temporal = fpl.Plot(size=(500, 120)) self._plot_temporal.camera.maintain_aspect = False @@ -564,8 +673,12 @@ def _set_data(self, data_arrays: Dict[str, np.ndarray]): self.component_int_box.max = n_components - 1 self.component_slider.max = n_components - 1 + # current state of CNMF object + # this can be different from the one in the dataframe if the user uses eval self._cnmf_obj: CNMF = data_arrays["cnmf_obj"] + self._eval_controller.set_limits(self._cnmf_obj) + def _euclidean(self, source, target, event, new_data): """maps click events to contour""" # calculate coms of line collection @@ -678,7 +791,7 @@ def set_component_colors( ------- """ - cnmf_obj = self._dataframe.iloc[self.current_row].cnmf.get_output() + cnmf_obj = self._cnmf_obj n_contours = len(self._image_widget.gridplot[0, 0]["contours"]) # use the random colors @@ -751,6 +864,28 @@ def _set_component_visibility(self, contours: fpl.LineCollection, cnmf_obj): # make everything visible contours[:].colors[:, -1] = 1 + def _set_eval(self, eval_params: dict): + index = self._get_selected_row() + # wonky caiman params object stuff + self._cnmf_obj.params.quality.update(eval_params) + + self._cnmf_obj.estimates.filter_components( + imgs=self._dataframe.iloc[index].caiman.get_input_movie(), + params=self._cnmf_obj.params + ) + + # set the colors + colors = self._dropdown_contour_colors.value + self.set_component_colors(colors) + + def _save_eval(self, obj): + index = self._get_selected_row() + + eval_params = self._eval_controller.get_data() + # this overwrites hdf5 file + self._dataframe.iloc[index].cnmf.run_eval(eval_params) + print("Overwrote CNMF object with new eval") + def show(self, sidecar: bool = False): """ Show the widget @@ -766,7 +901,7 @@ def show(self, sidecar: bool = False): temporals = VBox([self._plot_temporal.show(), self._plot_heatmap.show()]) - iw_contour_controls = VBox([self._image_widget.widget, self._box_contour_controls]) + iw_contour_controls = VBox([self._image_widget.widget, self._tab_contours_eval]) plots = HBox([temporals, iw_contour_controls]) From 7972877bcf3aa798097b83208fc08f04d1d8ef2a Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 1 Nov 2023 05:04:10 -0400 Subject: [PATCH 26/36] eval works and saves to disk --- mesmerize_viz/_cnmf.py | 37 +++++++++++++++++++++++++++++++------ 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/mesmerize_viz/_cnmf.py b/mesmerize_viz/_cnmf.py index e6b7293..fe9f7f9 100644 --- a/mesmerize_viz/_cnmf.py +++ b/mesmerize_viz/_cnmf.py @@ -10,6 +10,7 @@ from ipywidgets import Button, Tab, Text, Textarea, Layout, HBox, VBox, Checkbox, FloatSlider, BoundedFloatText, IntSlider, BoundedIntText, RadioButtons, Dropdown, jslink from tslearn.preprocessing import TimeSeriesScalerMeanVariance, TimeSeriesScalerMinMax import fastplotlib as fpl +from fastplotlib.utils import get_cmap from caiman.source_extraction.cnmf import CNMF @@ -169,7 +170,8 @@ def __init__(self): for metric in self._float_metrics: slider = FloatSlider(value=0, min=0, max=1, step=0.01, description=metric) spinbox = BoundedFloatText( - value=0, min=0, max=1, step=0.01, description_tooltip=metric, layout=Layout(width="70px") + value=0, min=0, max=1, step=0.01, + description_tooltip=metric, layout=Layout(width="70px"), readout_format='.2f', ) slider.observe(self._call_handlers, "value") @@ -188,14 +190,24 @@ def __init__(self): description_tooltip="use CNN classifier" ) - self.widget = VBox([*param_entries, self.use_cnn_checkbox]) - self._handlers = list() # limits must be set first before it's usable self._block_handlers = True - self.button_save_eval = Button(description="Save Eval") + self.button_save_eval = Button( + description="Save Eval", + description_tooltip="Saves CNMF hdf5 file using current evaluation" + ) + + self.button_reset_eval = Button( + description="Reset Eval Params", + description_tooltip="Reset eval params from the current hdf5 file" + ) + + buttons = HBox([self.button_save_eval, self.button_reset_eval]) + + self.widget = VBox([*param_entries, self.use_cnn_checkbox, buttons]) def set_limits(self, cnmf_obj: CNMF): self._block_handlers = True @@ -501,6 +513,7 @@ def __init__( self._tab_contours_eval = Tab() self._tab_contours_eval.children = [self._box_contour_controls, self._eval_controller.widget] + self._tab_contours_eval.titles = ["contour colors", "eval params"] # plots self._plot_temporal = fpl.Plot(size=(500, 120)) @@ -844,8 +857,13 @@ def set_component_colors( # this ensures that setting cmap_values will work subplot["contours"].cmap = "gray" - subplot["contours"].cmap_values = classifier - subplot["contours"].cmap = cmap + if len(np.unique(classifier)) == 1: + # TODO: patch until next fastplotlib release + color = get_cmap(cmap)[0] # set using first color in cmap + subplot["contours"][:].colors = color + else: + subplot["contours"].cmap_values = classifier + subplot["contours"].cmap = cmap self._set_component_visibility(subplot["contours"], cnmf_obj) @@ -878,6 +896,13 @@ def _set_eval(self, eval_params: dict): colors = self._dropdown_contour_colors.value self.set_component_colors(colors) + def _reset_eval(self, obj): + index = self._get_selected_row() + + # get CNMF object from cache + cnmf_obj = self._dataframe.iloc[index].cnmf.get_output() + self._set_eval(cnmf_obj.estimates.params.get_group("quality")) + def _save_eval(self, obj): index = self._get_selected_row() From acd8f22dd1456257c1de6f16e6ec68a84318f340 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 1 Nov 2023 05:11:48 -0400 Subject: [PATCH 27/36] params text --- mesmerize_viz/_cnmf.py | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/mesmerize_viz/_cnmf.py b/mesmerize_viz/_cnmf.py index fe9f7f9..4342ae2 100644 --- a/mesmerize_viz/_cnmf.py +++ b/mesmerize_viz/_cnmf.py @@ -14,7 +14,7 @@ from caiman.source_extraction.cnmf import CNMF -from ._utils import DummyMovie +from ._utils import DummyMovie, format_params IMAGE_OPTIONS = [ @@ -196,13 +196,12 @@ def __init__(self): self._block_handlers = True self.button_save_eval = Button( - description="Save Eval", - description_tooltip="Saves CNMF hdf5 file using current evaluation" + description="Save Eval to disk", ) self.button_reset_eval = Button( description="Reset Eval Params", - description_tooltip="Reset eval params from the current hdf5 file" + description_tooltip="Reset eval from disk" ) buttons = HBox([self.button_save_eval, self.button_reset_eval]) @@ -358,6 +357,7 @@ def __init__( raise ValueError( "The kwarg `component_indices` is not allowed here." ) + self._set_params_text_area(index) self.reset_timepoint_on_change = reset_timepoint_on_change self.input_movie_kwargs = input_movie_kwargs @@ -510,6 +510,7 @@ def __init__( self._eval_controller = EvalController() self._eval_controller.add_handler(self._set_eval) self._eval_controller.button_save_eval.on_click(self._save_eval) + self._eval_controller.button_reset_eval.on_click(self._reset_eval) self._tab_contours_eval = Tab() self._tab_contours_eval.children = [self._box_contour_controls, self._eval_controller.widget] @@ -580,6 +581,23 @@ def _get_row_data(self, index: int) -> Dict[str, np.ndarray]: return data_arrays + def _set_params_text_area(self, index): + row = self._dataframe.iloc[index] + # try and get the param diffs + try: + param_diffs = self._dataframe.caiman.get_params_diffs( + algo=row["algo"], + item_name=row["item_name"] + ).loc[index] + + diffs_dict = {"diffs": param_diffs.to_dict()} + diffs = f"{format_params(diffs_dict, 0)}\n\n" + except: + diffs = "" + + # diffs and full params + self.params_text_area.value = diffs + format_params(self._dataframe.iloc[index].params, 0) + def _row_changed(self, *args): index = self._get_selected_row() if index is None: @@ -600,6 +618,7 @@ def _row_changed(self, *args): else: # no exceptions, set plots self._set_data(data_arrays) + self._set_params_text_area(index) def _set_data(self, data_arrays: Dict[str, np.ndarray]): self._contour_graphics.clear() @@ -901,7 +920,8 @@ def _reset_eval(self, obj): # get CNMF object from cache cnmf_obj = self._dataframe.iloc[index].cnmf.get_output() - self._set_eval(cnmf_obj.estimates.params.get_group("quality")) + # reset eval + self._set_eval(cnmf_obj.params.get_group("quality")) def _save_eval(self, obj): index = self._get_selected_row() From a65cd64a761e56dd21b74395714f9ebb7dd4e754 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 1 Nov 2023 06:04:18 -0400 Subject: [PATCH 28/36] manual eval works --- mesmerize_viz/_cnmf.py | 142 +++++++++++++++++++++++++++++++++-------- 1 file changed, 117 insertions(+), 25 deletions(-) diff --git a/mesmerize_viz/_cnmf.py b/mesmerize_viz/_cnmf.py index 4342ae2..608e6e7 100644 --- a/mesmerize_viz/_cnmf.py +++ b/mesmerize_viz/_cnmf.py @@ -13,6 +13,8 @@ from fastplotlib.utils import get_cmap from caiman.source_extraction.cnmf import CNMF +from mesmerize_core.caiman_extensions.cnmf import cnmf_cache + from ._utils import DummyMovie, format_params @@ -254,6 +256,31 @@ def clear_handlers(self): class CNMFVizContainer: """Widget that contains the DataGrid, params text box fastplotlib GridPlot, etc""" + @property + def component_index(self) -> int: + """Current component index""" + return self._component_index + + @property + def plot_temporal(self) -> fpl.Plot: + """Plot with the single temporal trace""" + return self._plot_temporal + + @property + def plot_heatmap(self) -> fpl.Plot: + """Plot with the heatmap""" + return self._plot_heatmap + + @property + def image_widget(self) -> fpl.ImageWidget: + """ImageWidget""" + return self._image_widget + + @property + def cnmf_obj(self) -> CNMF: + """Current CNMF object displayed in the viewer, use with care""" + return self._cnmf_obj + def __init__( self, dataframe: pd.DataFrame, @@ -528,9 +555,15 @@ def __init__( self._contour_graphics: List[fpl.LineCollection] = list() + self._component_index = 0 + + self._cnmf_obj: CNMF = None + data_arrays = self._get_row_data(index=start_index) self._set_data(data_arrays) + self._set_params_text_area(index=start_index) + def _get_selected_row(self) -> Union[int, None]: r1 = self.datagrid.selections[0]["r1"] r2 = self.datagrid.selections[0]["r2"] @@ -628,24 +661,7 @@ def _set_data(self, data_arrays: Dict[str, np.ndarray]): self._plot_temporal.clear() self._plot_heatmap.clear() - if self._image_widget is None: - self._image_widget = fpl.ImageWidget( - data=data_arrays["images"], - names=self._image_data_options, - **self.image_widget_kwargs - ) - - # need to start it here so that we can access the toolbar to link events with the slider - self._image_widget.show() - - else: - # image widget doesn't need clear, we can just use set_data - self._image_widget.set_data(data_arrays["images"]) - for subplot in self._image_widget.gridplot: - if "contours" in subplot: - # delete the contour graphics - subplot.delete_graphic(subplot["contours"]) - + # make our temporal plots first, else image widget slider events could trigger linear selectors self._temporal_data = data_arrays["temporal"] # make temporal graphics @@ -667,6 +683,26 @@ def _set_data(self, data_arrays: Dict[str, np.ndarray]): self._synchronizer.add(self._linear_selector_temporal) self._synchronizer.add(self._linear_selector_heatmap) + if self._image_widget is None: + self._image_widget = fpl.ImageWidget( + data=data_arrays["images"], + names=self._image_data_options, + **self.image_widget_kwargs + ) + + self._image_widget.gridplot.renderer.add_event_handler(self._manual_toggle_component, "key_down") + + # need to start it here so that we can access the toolbar to link events with the slider + self._image_widget.show() + + else: + # image widget doesn't need clear, we can just use set_data + self._image_widget.set_data(data_arrays["images"]) + for subplot in self._image_widget.gridplot: + if "contours" in subplot: + # delete the contour graphics + subplot.delete_graphic(subplot["contours"]) + # absolute garbage monkey patch which I will fix once we make ImageWidget emit its own events if hasattr(self._image_widget.sliders["t"], "qslider"): self._image_widget.sliders["t"].qslider.valueChanged.connect(self._set_linear_selector_index_from_image_widget) @@ -693,12 +729,12 @@ def _set_data(self, data_arrays: Dict[str, np.ndarray]): image_graphic.link( "click", target=contour_graphic, - feature="colors", - new_data="w", + feature="thickness", + new_data=5, callback=self._euclidean ) - contour_graphic.link("colors", target=contour_graphic, feature="thickness", new_data=2) + # contour_graphic.link("colors", target=contour_graphic, feature="thickness", new_data=2) self.component_int_box.value = 0 self.component_slider.value = 0 @@ -744,10 +780,13 @@ def set_component_index(self, index): index = index.pick_info["selected_index"] for g in self._contour_graphics: - g.set_feature(feature="colors", new_data="w", indices=index) + g.set_feature(feature="thickness", new_data=8, indices=index) self._plot_temporal["line"].data = self._temporal_data[index] + # set the component index property + self._component_index = index + if self._component_linear_selector._move_info is None: # TODO: Very hacky for now, ignores if the slider is currently being moved by the user # prevents weird slider movement @@ -767,6 +806,7 @@ def set_component_index(self, index): self._component_metrics_text.value = metrics + def _zoom_into_component(self, index: int): if not self.checkbox_zoom_components.value: return @@ -924,13 +964,65 @@ def _reset_eval(self, obj): self._set_eval(cnmf_obj.params.get_group("quality")) def _save_eval(self, obj): + # this overwrites the hdf5 file index = self._get_selected_row() - eval_params = self._eval_controller.get_data() - # this overwrites hdf5 file - self._dataframe.iloc[index].cnmf.run_eval(eval_params) + # delete existing file + path = self._dataframe.iloc[index].cnmf.get_output_path() + path.unlink() + + # save to disk + self._cnmf_obj.save(str(path)) + + # clear the cache + cnmf_cache.clear_cache() + print("Overwrote CNMF object with new eval") + def _manual_toggle_component(self, ev): + if not hasattr(ev, "key"): + return + + if ev.key == "a": + if self.component_index in self._cnmf_obj.estimates.idx_components: + # component already in good + return + # else swap it from bad and put into good + + # remove from bad + self._cnmf_obj.estimates.idx_components_bad = np.delete( + self._cnmf_obj.estimates.idx_components_bad, + self._cnmf_obj.estimates.idx_components_bad == self.component_index + ) + + # put index into good + self._cnmf_obj.estimates.idx_components = np.sort(np.concatenate([ + self._cnmf_obj.estimates.idx_components, + [self.component_index] + ])) + + elif ev.key == "r": + if self.component_index in self._cnmf_obj.estimates.idx_components_bad: + # component already in bad + return + # else swap it from bad and put into good + + # remove from good + self._cnmf_obj.estimates.idx_components = np.delete( + self._cnmf_obj.estimates.idx_components, + self._cnmf_obj.estimates.idx_components == self.component_index + ) + + # put index into bad + self._cnmf_obj.estimates.idx_components_bad = np.sort(np.concatenate([ + self._cnmf_obj.estimates.idx_components_bad, + [self.component_index] + ])) + + # set the colors + colors = self._dropdown_contour_colors.value + self.set_component_colors(colors) + def show(self, sidecar: bool = False): """ Show the widget From 0feaf43641291920be474630dd70502eb1f8bd5f Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 1 Nov 2023 06:05:18 -0400 Subject: [PATCH 29/36] remove old cnmf stuff --- mesmerize_viz/_cnmf/__init__.py | 1 - mesmerize_viz/_cnmf/_eval.py | 71 ---- mesmerize_viz/_cnmf/_extensions.py | 95 ----- mesmerize_viz/_cnmf/_time_array.py | 137 ------- mesmerize_viz/_cnmf/_viz_container.py | 566 -------------------------- mesmerize_viz/_cnmf/_wrapper.py | 536 ------------------------ 6 files changed, 1406 deletions(-) delete mode 100644 mesmerize_viz/_cnmf/__init__.py delete mode 100644 mesmerize_viz/_cnmf/_eval.py delete mode 100644 mesmerize_viz/_cnmf/_extensions.py delete mode 100644 mesmerize_viz/_cnmf/_time_array.py delete mode 100644 mesmerize_viz/_cnmf/_viz_container.py delete mode 100644 mesmerize_viz/_cnmf/_wrapper.py diff --git a/mesmerize_viz/_cnmf/__init__.py b/mesmerize_viz/_cnmf/__init__.py deleted file mode 100644 index a93e4fd..0000000 --- a/mesmerize_viz/_cnmf/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from ._extensions import CNMFDataFrameVizExtension \ No newline at end of file diff --git a/mesmerize_viz/_cnmf/_eval.py b/mesmerize_viz/_cnmf/_eval.py deleted file mode 100644 index 492229e..0000000 --- a/mesmerize_viz/_cnmf/_eval.py +++ /dev/null @@ -1,71 +0,0 @@ -from typing import * -from collections import OrderedDict - -from ipywidgets import FloatSlider, FloatText, HBox, VBox, link, Layout, Label - - -class EvalWidgets: - def __init__(self): - # low thresholds - self._low_thresholds = OrderedDict( - rval_lowest=(-1.0, -1.0, 1.0), # (val, min, max) - SNR_lowest=(0.5, 0., 100), - cnn_lowest=(0.1, 0., 1.0), - ) - - # high thresholds - self.high_thresholds = OrderedDict( - rval_thr=(0.8, 0., 1.0), - min_SNR=(2.5, 0., 100), - min_cnn_thr=(0.9, 0., 1.0), - ) - - self._low_threshold_widget = list() - for k in self._low_thresholds: - kwargs = dict(value=self._low_thresholds[k][0], min=self._low_thresholds[k][1], max=self._low_thresholds[k][2], step=0.01, description=k) - slider = FloatSlider(**kwargs) - entry = FloatText(**kwargs, layout=Layout(width="150px")) - - link((slider, "value"), (entry, "value")) - - setattr(self, f"_{k}", entry) - - self._low_threshold_widget.append(HBox([slider, entry])) - - self._high_threshold_widgets = list() - for k in self.high_thresholds: - kwargs = dict(value=self.high_thresholds[k][0], min=self.high_thresholds[k][1], max=self.high_thresholds[k][2], step=0.01, description=k) - slider = FloatSlider(**kwargs) - entry = FloatText(**kwargs, layout=Layout(width="150px")) - - link((slider, "value"), (entry, "value")) - - setattr(self, f"_{k}", entry) - - self._high_threshold_widgets.append(HBox([slider, entry])) - - self.widget = VBox( - [ - Label("Low Thresholds"), - self._low_threshold_widget, - Label("High Thresholds"), - self._high_threshold_widgets - ] - ) - - def get_params(self): - """get the values from the GUI""" - - eval_params = dict() - for param in self._low_thresholds: - eval_params[param] = getattr(self, f"_{param}.value") - - for param in self._high_threshold_widgets: - eval_params[param] = getattr(self, f"_{param}.value") - - return eval_params - - def set_param(self, param: str, value: float): - w = getattr(self, f"_{param}") - - w.value = value diff --git a/mesmerize_viz/_cnmf/_extensions.py b/mesmerize_viz/_cnmf/_extensions.py deleted file mode 100644 index 0cef1d0..0000000 --- a/mesmerize_viz/_cnmf/_extensions.py +++ /dev/null @@ -1,95 +0,0 @@ -from typing import * - -import pandas as pd - -from ._viz_container import CNMFVizContainer - - -@pd.api.extensions.register_dataframe_accessor("cnmf") -class CNMFDataFrameVizExtension: - def __init__(self, df): - self._dataframe = df - - def viz( - self, - data_options: List[str] = None, - start_index: int = None, - reset_timepoint_on_change: bool = False, - data_graphic_kwargs: dict = None, - gridplot_kwargs: dict = None, - cmap: str = "gnuplot2", - component_colors: str = "random", - calcium_framerate: float = None, - other_data_loaders: Dict[str, callable] = None, - data_kwargs: dict = None, - data_grid_kwargs: dict = None, - ): - """ - Visualize motion correction output. - - Parameters - ---------- - data_options: list of list of str - default [["temporal"], ["heatmap-zscore"], ["input", "rcm", "rcb", "residuals"]] - - **Note:** You may add suffixes to temporal and heatmap options for "dfof", "zscore", "norm", - examples: "temporal-dfof", "heatmap-norm", "heatmap-zscore", "heatmap-dfof", etc. - - list of data to plot, valid options are: - - +------------------+-----------------------------------------+ - | data option | description | - +------------------+-----------------------------------------+ - | "input" | input movie | - | "rcm" | reconstructed movie, A * C | - | "rcb" | reconstructed background, b * f | - | "residuals" | residuals, input - (A * C) - (b * f) | - | "corr" | correlation image, if computed | - | "pnr" | peak-noise-ratio image, if computed | - | "temporal" | temporal components overlaid | - | "temporal-stack" | temporal components stack | - | "heatmap" | temporal components heatmap | - | "rcm-mean" | rcm mean projection image | - | "rcm-min" | rcm min projection image | - | "rcm-max" | rcm max projection image | - | "rcm-std" | rcm standard deviation projection image | - | "rcb-mean" | rcb mean projection image | - | "rcb-min" | rcb min projection image | - | "rcb-max" | rcb max projection image | - | "rcb-std" | rcb standard deviation projection image | - | "mean" | mean projection image | - | "max" | max projection image | - | "std" | standard deviation projection image | - +------------------+-----------------------------------------+ - - - start_index: int, default 0 - start index item used to set the initial data in the ImageWidget - - reset_timepoint_on_change: bool, default False - reset the timepoint in the ImageWidget when changing items/rows - - data_grid_kwargs: dict, optional - kwargs passed to DataGrid() - - Returns - ------- - McorrVizContainer - widget that contains the DataGrid, params text box and ImageWidget - """ - container = CNMFVizContainer( - dataframe=self._dataframe, - data=data_options, - start_index=start_index, - reset_timepoint_on_change=reset_timepoint_on_change, - data_graphic_kwargs=data_graphic_kwargs, - gridplot_kwargs=gridplot_kwargs, - cmap=cmap, - component_colors=component_colors, - calcium_framerate=calcium_framerate, - other_data_loaders=other_data_loaders, - data_kwargs=data_kwargs, - data_grid_kwargs=data_grid_kwargs, - ) - - return container diff --git a/mesmerize_viz/_cnmf/_time_array.py b/mesmerize_viz/_cnmf/_time_array.py deleted file mode 100644 index 08143bb..0000000 --- a/mesmerize_viz/_cnmf/_time_array.py +++ /dev/null @@ -1,137 +0,0 @@ -import math -from typing import Union, Tuple - -import numpy as np - -from mesmerize_core.arrays._base import LazyArray -from mesmerize_core.utils import quick_min_max - - -# TODO: maybe this can be used so that ImageWidget can be used for both behavior and calcium -# TODO: but then we need an option to set window_funcs separately for each subplot -class TimeArray(LazyArray): - """ - Wrapper for array-like that takes units of millisecond for slicing - """ - def __init__(self, array: Union[np.ndarray, LazyArray], timestamps = None, framerate = None): - """ - Arrays which can be sliced using timepoints in units of millisecond. - Supports slicing with start and stop timepoints, does not support slice steps. - - i.e. You can do this: time_array[30], time_array[30:], time_array[:50], time_array[30:50]. - You cannot do this: time_array[::10], time_array[0::10], time_array[0:50:10] - - Parameters - ---------- - array: array-like - data array, must have shape attribute and first dimension must be frame index - - timestamps: np.ndarray, 1 dimensional - timestamps in units of millisecond, you must provide either timestamps or framerate. - MUST be in order such that t_(n +1) > t_n for all n. - - framerate: float - framerate, in units of Hz (per second). You must provide either timestamps or framerate - """ - self._array = array - - if timestamps is None and framerate is None: - raise ValueError("Must provide timestamps or framerate") - - if timestamps is None: - # total duration in milliseconds = n_frames / framerate - n_frames = self.shape[0] - stop_time_ms = (n_frames / framerate) * 1000 - timestamps = np.linspace( - start=0, - stop=stop_time_ms, - num=n_frames, - endpoint=False - ) - - if timestamps.size != self._array.shape[0]: - raise ValueError("timestamps.size != array.shape[0]") - - self.timestamps = timestamps - - def _get_closest_index(self, timepoint: float): - """ - from: https://stackoverflow.com/a/26026189/4697983 - - This is very fast, 10 microseconds even for a - - Parameters - ---------- - timepoint: float - timepoint in milliseconds - - Returns - ------- - int - index for the closest timestamp, which also corresponds to the frame index of the data array - """ - value = timepoint - array = self.timestamps - - idx = np.searchsorted(array, value, side="left") - if idx > 0 and (idx == len(array) or math.fabs(value - array[idx - 1]) < math.fabs(value - array[idx])): - return idx - 1 - else: - return idx - - # override __getitem__ since it will not work with LazyArray base implementation since: - # 1. base implementation requires the slice indices to be less than shape[0] - # 2. base implementation does not consider slicing with float values - def __getitem__(self, indices: Union[slice, int, float]) -> np.ndarray: - if isinstance(indices, slice): - if indices.step is not None: - raise IndexError( - "TimeArray slicing does not support step, only start and stop. See docstring." - ) - - if indices.start is None: - start = 0 - else: - start = self._get_closest_index(indices.start) - - if indices.stop is None: - stop = self.n_frames - else: - stop = self._get_closest_index(indices.stop) - - s = slice(start, stop) - return self._array[s] - - # single index - index = self._get_closest_index(indices) - return self._array[index] - - def _compute_at_indices(self, indices: Union[int, slice]) -> np.ndarray: - """not implemented here""" - pass - - @property - def n_frames(self) -> int: - return self.shape[0] - - @property - def shape(self) -> Tuple[int, int, int]: - return self._array.shape - - @property - def dtype(self) -> str: - return str(self._array.dtype) - - @property - def min(self) -> float: - if isinstance(self._array, LazyArray): - return self._array.min - else: - return quick_min_max(self._array)[0] - - @property - def max(self) -> float: - if isinstance(self._array, LazyArray): - return self._array.max - else: - return quick_min_max(self._array)[1] diff --git a/mesmerize_viz/_cnmf/_viz_container.py b/mesmerize_viz/_cnmf/_viz_container.py deleted file mode 100644 index fed0c5e..0000000 --- a/mesmerize_viz/_cnmf/_viz_container.py +++ /dev/null @@ -1,566 +0,0 @@ -import itertools -from _warnings import warn -from functools import partial -from typing import * - -import pandas as pd -from ipydatagrid import DataGrid -from ipywidgets import Textarea, Layout, VBox, HBox, RadioButtons, Dropdown, FloatSlider -from IPython.display import display -from sidecar import Sidecar -import numpy as np -from tslearn.preprocessing import TimeSeriesScalerMeanVariance, TimeSeriesScalerMinMax - -from ._wrapper import ( - VALID_DATA_OPTIONS, GridPlotWrapper, projs, ExtensionCallWrapper, TEMPORAL_OPTIONS, - TEMPORAL_OPTIONS_DFOF, TEMPORAL_OPTIONS_ZSCORE, TEMPORAL_OPTIONS_NORM -) -from .._utils import format_params - - -def get_cnmf_data_mapping(series: pd.Series, data_kwargs: dict = None, other_data_loaders: dict = None) -> dict: - """ - Returns dict that maps data option str to a callable that can return the corresponding data array. - - For example, ``{"input": series.get_input_movie}`` maps "input" -> series.get_input_movie - - Parameters - ---------- - series: pd.Series - row/item to get mapping from - - data_kwargs: dict, optional - optional kwargs for each of the extension functions - - other_data_loaders: dict - {"data_option": callable}, example {"behavior": LazyVideo} - - Returns - ------- - dict - {data label: callable} - """ - if data_kwargs is None: - data_kwargs = dict() - - if other_data_loaders is None: - other_data_loaders = dict() - - default_extension_kwargs = {k: dict() for k in VALID_DATA_OPTIONS + list(other_data_loaders.keys())} - - default_extension_kwargs["contours"] = {"swap_dim": False} - - ext_kwargs = { - **default_extension_kwargs, - **data_kwargs - } - - projections = {k: partial(series.caiman.get_projection, k) for k in projs} - - other_data_loaders_mapping = dict() - - # make ExtensionCallWrapers for other data loaders - for option in list(other_data_loaders.keys()): - other_data_loaders_mapping[option] = ExtensionCallWrapper(other_data_loaders[option], ext_kwargs[option]) - - rcm_rcb_projs = dict() - for proj in ["mean", "min", "max", "std"]: - rcm_rcb_projs[f"rcm-{proj}"] = ExtensionCallWrapper( - series.cnmf.get_rcm, - ext_kwargs["rcm"], - attr=f"{proj}_image" - ) - - temporal_mappings = { - k: ExtensionCallWrapper(series.cnmf.get_temporal, ext_kwargs[k]) for k in TEMPORAL_OPTIONS - } - - dfof_mappings = { - k: ExtensionCallWrapper(series.cnmf.get_detrend_dfof, ext_kwargs[k]) for k in TEMPORAL_OPTIONS_DFOF - } - - zscore_mappings = { - k: ExtensionCallWrapper( - series.cnmf.get_temporal, ext_kwargs[k], post_process_func=TimeSeriesScalerMeanVariance().fit_transform - ) for k in TEMPORAL_OPTIONS_ZSCORE - } - - norm_mappings = { - k: ExtensionCallWrapper( - series.cnmf.get_temporal, ext_kwargs[k], post_process_func=TimeSeriesScalerMinMax().fit_transform - ) for k in TEMPORAL_OPTIONS_NORM - } - - m = { - "input": ExtensionCallWrapper(series.caiman.get_input_movie, ext_kwargs["input"]), - "rcm": ExtensionCallWrapper(series.cnmf.get_rcm, ext_kwargs["rcm"]), - "rcb": ExtensionCallWrapper(series.cnmf.get_rcb, ext_kwargs["rcb"]), - "residuals": ExtensionCallWrapper(series.cnmf.get_residuals, ext_kwargs["residuals"]), - "corr": ExtensionCallWrapper(series.caiman.get_corr_image, ext_kwargs["corr"]), - "contours": ExtensionCallWrapper(series.cnmf.get_contours, ext_kwargs["contours"]), - "empty": None, - **temporal_mappings, - **dfof_mappings, - **zscore_mappings, - **norm_mappings, - **projections, - **rcm_rcb_projs, - **other_data_loaders_mapping - } - - return m - - -class CNMFVizContainer: - """Widget that contains the DataGrid, params text box fastplotlib GridPlot, etc""" - - def __init__( - self, - dataframe: pd.DataFrame, - data: List[str] = None, - start_index: int = None, - reset_timepoint_on_change: bool = False, - data_graphic_kwargs: dict = None, - gridplot_kwargs: dict = None, - cmap: str = "gnuplot2", - component_colors: str = "random", - calcium_framerate: float = None, - other_data_loaders: Dict[str, callable] = None, - data_kwargs: dict = None, - data_grid_kwargs: dict = None, - ): - """ - Visualize CNMF output and other data columns such as behavior video (optional) - - Parameters - ---------- - dataframe: pd.DataFrame - - data: list of str, or list of list of str - data options, such as "input", "temporal", "contours", etc. - - start_index - - reset_timepoint_on_change - - calcium_framerate - - other_data_loaders: Dict[str, callable] - if loading non-calcium related data arrays, provide dict of callables for opening them. - Example, if you provide ``data = ["contours", "temporal", "behavior"]``, and the "behavior" - column contains videos, you could provide `other_data_loads = {"behavior": LazyVideo} - - data_kwargs: dict - kwargs passed to corresponding extension function to load data. - example: ``{"temporal": {"component_ixs": "good"}}`` - - gridplot_kwargs: List[dict] - kwargs passed to GridPlot - - data_grid_kwargs - """ - - if data is None: - data = [["temporal"], ["heatmap-zscore"], ["input", "rcm", "rcb", "residuals"]] - # if it's the default options, it will hstack the temporal and heatmap next to the image data - self.default = True - else: - self.default = False - - if other_data_loaders is None: - other_data_loaders = dict() - - # simple list of str, single gridplot - if all(isinstance(option, str) for option in data): - data = [data] - - if not all(isinstance(option, list) for option in data): - raise TypeError( - "Must pass list of str or nested list of str" - ) - - # make sure data options are valid - for d in list(itertools.chain(*data)): - if (d not in VALID_DATA_OPTIONS) and (d not in dataframe.columns): - raise ValueError( - f"`data` options are: {VALID_DATA_OPTIONS} or a DataFrame column name: {dataframe.columns}\n" - f"You have passed: {d}" - ) - - if d in dataframe.columns: - if d not in other_data_loaders.keys(): - raise ValueError( - f"You have provided the non-CNMF related data option: {d}.\n" - f"If you provide a non-cnmf related data option you must also provide a " - f"data loader callable for it to `other_data_loaders`" - ) - - self._other_data_loaders = other_data_loaders - - if data_grid_kwargs is None: - data_grid_kwargs = dict() - - self._dataframe = dataframe - - default_widths = { - "algo": 50, - 'item_name': 200, - 'input_movie_path': 120, - 'algo_duration': 80, - 'comments': 120, - 'uuid': 60 - } - - columns = dataframe.columns - # these add clutter - hide_columns = [ - "params", - "outputs", - "added_time", - "ran_time", - - ] - - df_show = self._dataframe[[c for c in columns if c not in hide_columns]] - - self.datagrid = DataGrid( - df_show, # show only a subset - selection_mode="cell", - layout={"height": "250px", "width": "750px"}, - base_row_size=24, - index_name="index", - column_widths=default_widths, - **data_grid_kwargs - ) - - self.params_text_area = Textarea() - self.params_text_area.layout = Layout( - height="250px", - max_height="250px", - width="360px", - max_width="500px", - disabled=True, - ) - - # data options is private since this can't be changed once an image widget has been made - self._data = data - - if data_kwargs is None: - data_kwargs = dict() - - self.data_kwargs = data_kwargs - - if start_index is None: - # try to guess the start index - start_index = dataframe[dataframe.algo == "cnmf"].iloc[0].name - - self.current_row: int = start_index - - self._random_colors = None - - self._make_gridplot( - start_index=start_index, - reset_timepoint_on_change=reset_timepoint_on_change, - data_graphic_kwargs=data_graphic_kwargs, - gridplot_kwargs=gridplot_kwargs, - cmap=cmap, - component_colors=component_colors, - ) - - self._set_params_text_area(index=start_index) - - # set initial selected row - self.datagrid.select( - row1=start_index, - column1=0, - row2=start_index, - column2=len(df_show.columns), - clear_mode="all" - ) - - # callback when row changed - self.datagrid.observe(self._row_changed, names="selections") - - self._dropdown_contour_colors = Dropdown( - options=["random", "accepted", "rejected", "snr_comps", "snr_comps_log", "r_values", "cnn_preds"], - value="random", - description='contour colors:', - ) - - self._dropdown_contour_colors.observe(self._ipywidget_set_component_colors, "value") - - self._radio_visible_components = RadioButtons( - options=["all", "accepted", "rejected"], - description_tooltip="contours to make visible", - description="visible contours" - ) - - self._radio_visible_components.observe(self._ipywidget_set_component_colors, "value") - - self._spinbox_alpha_invisible_contours = FloatSlider( - value=0.0, - min=0.0, - max=1.0, - step=0.1, - description="invisible alpha:", - description_tooltip="transparency of contours set to be invisible", - disabled=False - ) - - self._spinbox_alpha_invisible_contours.observe(self._ipywidget_set_component_colors, "value") - - self._box_contour_controls = VBox([ - self._dropdown_contour_colors, - HBox([self._radio_visible_components, self._spinbox_alpha_invisible_contours]) - ]) - - self.sidecar = None - - def _make_gridplot( - self, - start_index: int, - reset_timepoint_on_change: bool, - data_graphic_kwargs: dict, - gridplot_kwargs: dict, - cmap: str, - component_colors: str, - ): - - data_mapping = get_cnmf_data_mapping( - self._dataframe.iloc[start_index], - self.data_kwargs - ) - - cnmf_obj = self._dataframe.iloc[start_index].cnmf.get_output() - n_contours = cnmf_obj.estimates.C.shape[0] - - self._random_colors = np.random.rand(n_contours, 4).astype(np.float32) - self._random_colors[:, -1] = 1 - - self._gridplot_wrapper = GridPlotWrapper( - data=self._data, - data_mapping=data_mapping, - reset_timepoint_on_change=reset_timepoint_on_change, - data_graphic_kwargs=data_graphic_kwargs, - gridplot_kwargs=gridplot_kwargs, - cmap=cmap, - component_colors=component_colors - - ) - - self.gridplots = self._gridplot_wrapper.gridplots - - def show(self, sidecar: bool = True): - """Show the widget""" - - # create gridplots and start render loop - gridplots = [gp.show(sidecar=False) for gp in self.gridplots] - - # contour color controls and auto-zoom - contour_controls = VBox( - [ - HBox([self._gridplot_wrapper.checkbox_zoom_components, self._gridplot_wrapper.zoom_components_scale]), - self._box_contour_controls - ] - ) - - if "Jupyter" in self.gridplots[0].canvas.__class__.__name__: - if self.default: - # TODO: let's just make this the mandatory behavior, temporal + heatmap on left, any image stuff on right - # temporal and heatmap on left side, image data on right side - gridplot_elements = HBox([VBox(gridplots[:2]), VBox([gridplots[2], contour_controls])]) - else: - gridplot_elements = VBox(gridplots) - else: - raise NotImplemented("show() not implemented outside of jupyter") - gridplot_elements = list() - - if self.sidecar is None: - self.sidecar = Sidecar() - - widget = VBox( - [ - HBox([self.datagrid, self.params_text_area]), - HBox([self._gridplot_wrapper.component_slider, self._gridplot_wrapper.component_int_box]), - gridplot_elements, - ] - ) - - with self.sidecar: - return display(widget) - - def close(self): - """Close the widget""" - for gp in self.gridplots: - gp.close() - - def _get_selection_row(self) -> Union[int, None]: - r1 = self.datagrid.selections[0]["r1"] - r2 = self.datagrid.selections[0]["r2"] - - if r1 != r2: - warn("Only single row selection is currently allowed") - return - - # get corresponding dataframe index from currently visible dataframe - # since filtering etc. is possible - index = self.datagrid.get_visible_data().index[r1] - - return index - - def _row_changed(self, *args): - index = self._get_selection_row() - if index is None: - return - - if self.current_row == index: - return - - try: - data_mapping = get_cnmf_data_mapping( - self._dataframe.iloc[index], - self.data_kwargs - ) - self._gridplot_wrapper.change_data(data_mapping) - except Exception as e: - self.params_text_area.value = f"{type(e).__name__}\n" \ - f"{str(e)}\n\n" \ - f"See jupyter log for details" - raise e - - self._set_params_text_area(index) - - cnmf_obj = self._dataframe.iloc[index].cnmf.get_output() - n_contours = cnmf_obj.estimates.C.shape[0] - - self._random_colors = np.random.rand(n_contours, 4).astype(np.float32) - self._random_colors[:, -1] = 1 - - self.current_row = index - - def _set_params_text_area(self, index): - row = self._dataframe.iloc[index] - # try and get the param diffs - try: - param_diffs = self._dataframe.caiman.get_params_diffs( - algo=row["algo"], - item_name=row["item_name"] - ).iloc[index] - - diffs_dict = {"diffs": param_diffs} - diffs = f"{format_params(diffs_dict, 0)}\n\n" - except: - diffs = "" - - # diffs and full params - self.params_text_area.value = diffs + format_params(self._dataframe.iloc[index].params, 0) - - - @property - def cmap(self) -> str: - return self._gridplot_wrapper.cmap - - @cmap.setter - def cmap(self, cmap: str): - for g in self._gridplot_wrapper.image_graphics: - g.cmap = cmap - - def _set_component_visibility(self, contours, cnmf_obj): - visible = self._radio_visible_components.value - alpha_invisible = self._spinbox_alpha_invisible_contours.value - - # choose to make all or accepted or rejected visible - if visible == "accepted": - contours[cnmf_obj.estimates.idx_components_bad].colors[:, -1] = alpha_invisible - - elif visible == "rejected": - contours[cnmf_obj.estimates.idx_components].colors[:, -1] = alpha_invisible - - else: - # make everything visible - contours[:].colors[:, -1] = 1 - - def _ipywidget_set_component_colors(self, *args): - """just a wrapper to make ipywidgets happy""" - colors = self._dropdown_contour_colors.value - self.set_component_colors(colors) - - def set_component_colors( - self, - colors: Union[str, np.ndarray], - cmap: str = None, - ): - """ - - Parameters - ---------- - colors: str or np.ndarray - np.ndarray or one of: random, accepted, rejected, accepted-rejected, snr_comps, snr_comps_log, - r_values, cnn_preds - - If np.ndarray, it must be of the same length as the number of components - - cmap: str - custom cmap for the colors - - Returns - ------- - - """ - cnmf_obj = self._dataframe.iloc[self.current_row].cnmf.get_output() - n_contours = len(self._gridplot_wrapper.contour_graphics[0]) - - if colors == "random": - colors = self._random_colors - for contours in self._gridplot_wrapper.contour_graphics: - for i, g in enumerate(contours.graphics): - g.colors = colors[i] - - self._set_component_visibility(contours, cnmf_obj) - - return - - if colors in ["accepted", "rejected"]: - if cmap is None: - cmap = "Set1" - - # make a empty array for cmap_values - classifier = np.zeros(n_contours, dtype=int) - # set the accepted components to 1 - classifier[cnmf_obj.estimates.idx_components] = 1 - - else: - if cmap is None: - cmap = "spring" - - if colors == "snr_comps": - classifier = cnmf_obj.estimates.SNR_comp - - elif colors == "snr_comps_log": - classifier = np.log10(cnmf_obj.estimates.SNR_comp) - - elif colors == "r_values": - classifier = cnmf_obj.estimates.r_values - - elif colors == "cnn_preds": - classifier = cnmf_obj.estimates.cnn_preds - - elif isinstance(colors, np.ndarray): - if not colors.size == n_contours: - raise ValueError(f"If using np.ndarray cor component_colors, the array size must be " - f"the same as n_contours: {n_contours}, your array size is: {colors.size}") - - classifier = colors - - else: - raise ValueError("Invalid colors value") - - for contours in self._gridplot_wrapper.contour_graphics: - # first initialize using a quantitative cmap - # this ensures that setting cmap_values will work - contours.cmap = "gray" - - contours.cmap_values = classifier - contours.cmap = cmap - - self._set_component_visibility(contours, cnmf_obj) diff --git a/mesmerize_viz/_cnmf/_wrapper.py b/mesmerize_viz/_cnmf/_wrapper.py deleted file mode 100644 index 67b5f3d..0000000 --- a/mesmerize_viz/_cnmf/_wrapper.py +++ /dev/null @@ -1,536 +0,0 @@ -from itertools import product -from typing import Union, List, Dict - -import numpy as np -from ipywidgets import IntSlider, BoundedIntText, jslink, Checkbox, FloatSlider, RadioButtons - -from fastplotlib import GridPlot, graphics -from fastplotlib.graphics.selectors import LinearSelector, Synchronizer -from fastplotlib.utils import calculate_gridshape - - -# basic data options -VALID_DATA_OPTIONS = [ - "contours", - "empty" -] - - -IMAGE_OPTIONS = [ - "input", - "rcm", - "rcb", - "residuals", - "corr", - "pnr", -] - -rcm_rcb_proj_options = list() -# RCM and RCB projections -for option in ["rcm", "rcb"]: - for proj in ["mean", "min", "max", "std"]: - rcm_rcb_proj_options.append(f"{option}-{proj}") - -IMAGE_OPTIONS += rcm_rcb_proj_options - -TEMPORAL_OPTIONS = [ - "temporal", - "temporal-stack", - "heatmap", -] - -TEMPORAL_OPTIONS_DFOF = [ - f"{option}-dfof" for option in TEMPORAL_OPTIONS -] - -TEMPORAL_OPTIONS_ZSCORE = [ - f"{option}-zscore" for option in TEMPORAL_OPTIONS -] - -TEMPORAL_OPTIONS_NORM = [ - f"{option}-norm" for option in TEMPORAL_OPTIONS -] - -TEMPORAL_OPTIONS_ALL = TEMPORAL_OPTIONS + TEMPORAL_OPTIONS_DFOF + TEMPORAL_OPTIONS_ZSCORE + TEMPORAL_OPTIONS_ZSCORE + TEMPORAL_OPTIONS_NORM - -projs = [ - "mean", - "max", - "std", -] - -IMAGE_OPTIONS += projs - -VALID_DATA_OPTIONS += IMAGE_OPTIONS -VALID_DATA_OPTIONS += TEMPORAL_OPTIONS_ALL - - -class ExtensionCallWrapper: - def __init__( - self, - extension_func: callable, - kwargs: dict = None, - attr: str = None, - post_process_func: callable = None, - ): - """ - Basically a very fancy ``functools.partial``. - - In addition to behaving like ``functools.partial``, it supports: - - kwargs - - returning attributes of the return value from the callable - - postprocessing the return value - - Parameters - ---------- - extension_func: callable - extension function reference - - kwargs: dict - kwargs to pass to the extension function when it is called - - attr: str, optionalself, extension_func: callable, kwargs: dict = None, attr: str = None - return an attribute of the callable's output instead of the return value of the callable. - Example: if using rcm, can set ``attr="max_image"`` to return the max proj of the RCM. - - post_process_func: callable - A function to postprocess before returning, such as zscore, etc. - """ - - if kwargs is None: - self.kwargs = dict() - else: - self.kwargs = kwargs - - self.func = extension_func - self.attr = attr - self.post_process_func = post_process_func - - def __call__(self, *args, **kwargs): - rval = self.func(**self.kwargs) - - if self.attr is not None: - return getattr(rval, self.attr) - - if self.post_process_func is not None: - return self.post_process_func(rval) - - return rval - - -class GridPlotWrapper: - """Wraps GridPlot in a way that allows updating the data""" - - def __init__( - self, - data: Union[List[str], List[List[str]]], - data_mapping: Dict[str, ExtensionCallWrapper], - reset_timepoint_on_change: bool = False, - data_graphic_kwargs: dict = None, - # slider_ipywidget: ipywidgets.IntSlider = None, - gridplot_kwargs: dict = None, - cmap: str = "gnuplot2", - component_colors: str = "random" - ): - """ - Visualize motion correction output. - - Parameters - ---------- - data: list of str or list of list of str - list of data to plot, examples: ["input", "temporal-stack"], [["temporal"], ["rcm", "rcb"]] - - data_mapping: dict - maps {"data_option": callable} - - reset_timepoint_on_change: bool, default False - reset the timepoint in the ImageWidget when changing items/rows - - data_graphic_kwargs: dict - passed add_ for corresponding graphic - - slider_ipywidget: ipywidgets.IntSlider - time slider from ImageWidget - - gridplot_kwargs: dict, optional - kwargs passed to GridPlot - - """ - - self._data = data - - if data_graphic_kwargs is None: - data_graphic_kwargs = dict() - - self.data_graphic_kwargs = data_graphic_kwargs - - if gridplot_kwargs is None: - gridplot_kwargs = dict() - - self.cmap = cmap - - self.component_colors = component_colors - - # self._slider_ipywidget = slider_ipywidget - - self.reset_timepoint_on_change = reset_timepoint_on_change - - self.gridplots: List[GridPlot] = list() - - self.component_slider = IntSlider(min=0, max=1, value=0, step=1, description="component index:") - self.component_int_box = BoundedIntText(min=0, max=1, value=0, step=1) - for trait in ["value", "max"]: - jslink((self.component_slider, trait), (self.component_int_box, trait)) - - self.component_int_box.observe( - lambda change: self.set_component_index(change["new"]), "value" - ) - - # gridplot for each sublist - for sub_data in self._data: - # make the kwargs - final_gridplot_kwargs = { - "shape": calculate_gridshape(len(sub_data)), - "controllers": "sync" - } - # merge with any use-specified kwargs - # user-specified kwargs will override anything specified here - - final_gridplot_kwargs.update(gridplot_kwargs) - - # instantiate gridplot and add to list of gridplots - self.gridplots.append( - GridPlot(**final_gridplot_kwargs) - ) - - self.temporal_graphics: List[graphics.LineGraphic] = list() - self.temporal_stack_graphics: List[graphics.LineStack] = list() - self.heatmap_graphics: List[graphics.HeatmapGraphic] = list() - self.image_graphics: List[graphics.ImageGraphic] = list() - self.contour_graphics: List[graphics.LineCollection] = list() - - self.heatmap_component_ix_selectors: List[LinearSelector] = list() # selects heatmap rows, i.e. components - - self._managed_graphics: List[list] = [ - self.temporal_graphics, - self.temporal_stack_graphics, - self.image_graphics, - self.contour_graphics - ] - - # to store only image data in a 1:1 mapping to the graphics list - self.image_graphic_arrays: List[np.ndarray] = list() - - self.linear_selectors: List[LinearSelector] = list() # select current timepoint, i.e. frame index - - self._synchronizer: Synchronizer = Synchronizer(key_bind=None) # synchronizes linear_selectors - - self._current_frame_index: int = 0 - - self._current_temporal_components: np.ndarray = None - - self.checkbox_zoom_components = Checkbox( - value=True, - description="auto-zoom component", - description_tooltip="If checked, zoom into selected component" - ) - - self.zoom_components_scale = FloatSlider( - min=0.25, - max=3, - value=1, - step=0.25, - description="zoom scale", - description_tooltip="scale if zoom components is checked" - ) - - self.change_data(data_mapping) - - def _zoom_into_component(self, index: int): - if not self.checkbox_zoom_components.value: - return - - for gridplot in self.gridplots: - for subplot in gridplot: - if "contours" not in subplot: - continue - - subplot.camera.show_object( - subplot["contours"].graphics[index].world_object, - scale=self.zoom_components_scale.value - ) - - def set_component_index(self, index: int): - # TODO: more elegant way than skip_heatmap - for g in self.contour_graphics: - g.set_feature(feature="colors", new_data="w", indices=index) - - for g in self.temporal_graphics: - g.data = self._current_temporal_components[index] - - for s in self.heatmap_component_ix_selectors: - # TODO: Very hacky for now, ignores if the slider is currently being moved, prevents weird slider movement - if s._move_info is None: - s.selection = index - - self.component_int_box.value = index - - self._zoom_into_component(index) - - def _heatmap_set_component_index(self, ev): - index = ev.pick_info["selected_index"] - - if ev.pick_info["pygfx_event"] is None: - # this means that the selector was not triggered by the user but that it moved due to another event - # so we don't set_component_index because then infinite recursion - return - - self.set_component_index(index) - - def _parse_data(self, data_options, data_mapping) -> List[List[np.ndarray]]: - """ - Returns nested list of array-like - """ - data_arrays = list() - - for d in data_options: - if isinstance(d, list): - data_arrays.append(self._parse_data(d, data_mapping)) - - elif d == "empty": - data_arrays.append(None) - - else: - func = data_mapping[d] - a = func() - data_arrays.append(a) - - return data_arrays - - def change_data(self, data_mapping: Dict[str, callable]): - """ - Changes the data shown in the gridplot. - - Clears all the gridplots, makes and adds new graphics - - Parameters - ---------- - data_mapping - - Returns - ------- - - """ - for l in self._managed_graphics: - l.clear() - - self._synchronizer.clear() # must clear synchronizer first before the selectors, else lingering weakrefs - self.heatmap_component_ix_selectors.clear() - self.linear_selectors.clear() - - self.image_graphic_arrays.clear() - - # clear out old array that stores temporal components - self._current_temporal_components = None - - # clear existing subplots - for gp in self.gridplots: - gp.clear() - - # new data arrays - data_arrays = self._parse_data(data_options=self._data, data_mapping=data_mapping) - - # rval is (contours, centeres of masses) - contours = data_mapping["contours"]()[0] - - if self.component_colors == "random": - n_components = len(contours) - component_colors = np.random.rand(n_components, 4).astype(np.float32) - component_colors[:, -1] = 1 - else: - component_colors = self.component_colors - - self.component_slider.value = 0 - self.component_slider.max = len(contours) - 1 - - # change data for all gridplots - for sub_data, sub_data_arrays, gridplot in zip(self._data, data_arrays, self.gridplots): - self._change_data_gridplot(sub_data, sub_data_arrays, gridplot, contours, component_colors) - - # connect events - self._connect_events() - - for ls in self.linear_selectors: - ls.selection.add_event_handler(self.set_frame_index) - self._synchronizer.add(ls) # sync linear_selectors - - for hs in self.heatmap_component_ix_selectors: - hs.selection.add_event_handler(self._heatmap_set_component_index) - - def _change_data_gridplot( - self, - data: List[str], - data_arrays: List[np.ndarray], - gridplot: GridPlot, - contours, - component_colors - ): - """ - Changes data in a single gridplot. - - Create the corresponding graphics. - - Parameters - ---------- - data - data_arrays - gridplot - contours - component_colors - - Returns - ------- - - """ - - if self.reset_timepoint_on_change: - self._current_frame_index = 0 - - for data_option, data_array, subplot in zip(data, data_arrays, gridplot): - if data_option in self.data_graphic_kwargs.keys(): - graphic_kwargs = self.data_graphic_kwargs[data_option] - else: - graphic_kwargs = dict() - # skip - if data_option == "empty": - continue - - elif data_option.startswith("temporal") and "stack" not in data_option: - # Only few one line at a time - current_graphic = subplot.add_line( - data_array[0], - colors="w", - name="line", - **graphic_kwargs - ) - - current_graphic.data.add_event_handler(subplot.auto_scale) - self.temporal_graphics.append(current_graphic) - - if self._current_temporal_components is None: - self._current_temporal_components = data_array - - # otherwise the plot has nothing in it which causes issues - # subplot.add_line(np.random.rand(data_array.shape[1]), colors=(0, 0, 0, 0), name="pseudo-line") - - # scale according to temporal dims - subplot.camera.maintain_aspect = False - - elif data_option.startswith("temporal-stack"): - current_graphic = subplot.add_line_stack( - data_array, - colors=component_colors, - name="lines", - **graphic_kwargs - ) - self.temporal_stack_graphics.append(current_graphic) - - # scale according to temporal dims - subplot.camera.maintain_aspect = False - - elif data_option.startswith("heatmap"): - current_graphic = subplot.add_heatmap( - data_array, - name="heatmap", - **graphic_kwargs - ) - self.heatmap_graphics.append(current_graphic) - - # scale according to temporal dims - subplot.camera.maintain_aspect = False - - selector = current_graphic.add_linear_selector( - axis="y", - color=(1, 1, 1, 0.5), - thickness=5, - ) - - self.heatmap_component_ix_selectors.append(selector) - - else: - # else it is an image - if data_array.ndim == 3: - frame = data_array[self._current_frame_index] - else: - frame = data_array - img_graphic = subplot.add_image( - frame, - cmap=self.cmap, - name="image", - **graphic_kwargs - ) - - self.image_graphics.append(img_graphic) - self.image_graphic_arrays.append(data_array) - - contour_graphic = subplot.add_line_collection( - contours, - colors=component_colors, - name="contours" - ) - - self.contour_graphics.append(contour_graphic) - - subplot.name = data_option - - if data_option in TEMPORAL_OPTIONS_ALL: - self.linear_selectors.append(current_graphic.add_linear_selector()) - subplot.camera.maintain_aspect = False - - def _euclidean(self, source, target, event, new_data): - """maps click events to contour""" - # calculate coms of line collection - indices = np.array(event.pick_info["index"]) - - coms = list() - - for contour in target.graphics: - coors = contour.data()[~np.isnan(contour.data()).any(axis=1)] - com = coors.mean(axis=0) - coms.append(com) - - # euclidean distance to find closest index of com - indices = np.append(indices, [0]) - - ix = int(np.linalg.norm((coms - indices), axis=1).argsort()[0]) - - self.set_component_index(ix) - - self.component_int_box.value = ix - - return None - - def _connect_events(self): - for image_graphic, contour_graphic in zip(self.image_graphics, self.contour_graphics): - image_graphic.link( - "click", - target=contour_graphic, - feature="colors", - new_data="w", - callback=self._euclidean - ) - - contour_graphic.link("colors", target=contour_graphic, feature="thickness", new_data=5) - - for cg, tsg in product(self.contour_graphics, self.temporal_stack_graphics): - cg.link("colors", target=contour_graphic, feature="colors", new_data="w", bidirectional=True) - - def set_frame_index(self, ev): - # 0 because this will return the same number repeated * n_components - index = ev.pick_info["selected_index"] - for image_graphic, full_array in zip(self.image_graphics, self.image_graphic_arrays): - # txy data - if full_array.ndim > 2: - image_graphic.data = full_array[index] - - self._current_frame_index = index From fbbf4894b0832ae73356159159b8b06471e1c594 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 1 Nov 2023 06:18:35 -0400 Subject: [PATCH 30/36] add extension --- mesmerize_viz/_cnmf.py | 120 +++++++++++++++++++++++++++++++++++++---- 1 file changed, 110 insertions(+), 10 deletions(-) diff --git a/mesmerize_viz/_cnmf.py b/mesmerize_viz/_cnmf.py index 608e6e7..5b61be9 100644 --- a/mesmerize_viz/_cnmf.py +++ b/mesmerize_viz/_cnmf.py @@ -14,6 +14,7 @@ from caiman.source_extraction.cnmf import CNMF from mesmerize_core.caiman_extensions.cnmf import cnmf_cache +from mesmerize_core import CNMFExtensions from ._utils import DummyMovie, format_params @@ -317,13 +318,8 @@ def __init__( if dfof-norm: uses cnmf.get_dfof() and then 0-1 normalizes - reset_timepoint_on_change: bool - - temporal_postprocess: optional, list of str or callable - - heatmap_postprocess: str, None, callable - if str: one of "norm", "dfof", "zscore" - Or a callable to postprocess using your own function + reset_timepoint_on_change: bool, default False + reset the timepoint in the ImageWidget when changing items/rows temporal_kwargs: dict kwargs passed to cnmf.get_temporal(), example: {"add_residuals" : True}. @@ -332,6 +328,11 @@ def __init__( input_movie_kwargs: dict kwargs passed to caiman.get_input() + image_widget_kwargs: dict + kwargs passed to ImageWidget + + Example: `image_widget_kwargs={"cmap": "viridis"}` + data_grid_kwargs """ @@ -384,7 +385,6 @@ def __init__( raise ValueError( "The kwarg `component_indices` is not allowed here." ) - self._set_params_text_area(index) self.reset_timepoint_on_change = reset_timepoint_on_change self.input_movie_kwargs = input_movie_kwargs @@ -697,7 +697,11 @@ def _set_data(self, data_arrays: Dict[str, np.ndarray]): else: # image widget doesn't need clear, we can just use set_data - self._image_widget.set_data(data_arrays["images"]) + self._image_widget.set_data( + data_arrays["images"], + reset_indices=self.reset_timepoint_on_change, + reset_vmin_vmax=True + ) for subplot in self._image_widget.gridplot: if "contours" in subplot: # delete the contour graphics @@ -806,7 +810,6 @@ def set_component_index(self, index): self._component_metrics_text.value = metrics - def _zoom_into_component(self, index: int): if not self.checkbox_zoom_components.value: return @@ -1043,3 +1046,100 @@ def show(self, sidecar: bool = False): plots = HBox([temporals, iw_contour_controls]) return VBox([self._top_widget, plots]) + + +@pd.api.extensions.register_dataframe_accessor("cnmf") +class CNMFDataFrameVizExtension: + def __init__(self, df): + self._dataframe = df + + def viz( + self, + start_index: int = None, + temporal_data_option: str = None, + image_data_options: list[str] = None, + temporal_kwargs: dict = None, + reset_timepoint_on_change: bool = False, + input_movie_kwargs: dict = None, + image_widget_kwargs: dict = None, + data_grid_kwargs: dict = None, + ): + """ + Visualize CNMF output and other data columns such as behavior video (optional). + + Note: If using dfof temporal_data_option, you must have already run dfof. + + Parameters + ---------- + dataframe: pd.DataFrame + + start_index: int + + temporal_data_option: optional, str + if not provided or ``None`: uses cnmf.get_temporal() + + if zscore: uses zscore of cnmf.get_temporal() + + if norm: uses 0-1 normalized output of cnmf.get_temporal() + + if dfof: uses cnmf.get_dfof() + + if dfof-zscore: uses cnmf.get_dfof() and then zscores + + if dfof-norm: uses cnmf.get_dfof() and then 0-1 normalizes + + image_data_options: list of str + default: ["input", "rcm", "rcb", "residuals"] + + Valid options: + + +------------------+-----------------------------------------+ + | option | description | + +------------------+-----------------------------------------+ + | "input" | input movie | + | "rcm" | reconstructed movie, A * C | + | "rcb" | reconstructed background, b * f | + | "residuals" | residuals, input - (A * C) - (b * f) | + | "rcm-mean" | rcm mean projection image | + | "rcm-min" | rcm min projection image | + | "rcm-max" | rcm max projection image | + | "rcm-std" | rcm standard deviation projection image | + | "rcb-mean" | rcb mean projection image | + | "rcb-min" | rcb min projection image | + | "rcb-max" | rcb max projection image | + | "rcb-std" | rcb standard deviation projection image | + | "mean" | mean projection image | + | "max" | max projection image | + | "std" | standard deviation projection image | + +------------------+-----------------------------------------+ + + reset_timepoint_on_change: bool, default False + reset the timepoint in the ImageWidget when changing items/rows + + temporal_kwargs: dict + kwargs passed to cnmf.get_temporal(), example: {"add_residuals" : True}. + Ignored if temporal_data_option contains "dfof" + + input_movie_kwargs: dict + kwargs passed to caiman.get_input() + + image_widget_kwargs: dict + kwargs passed to ImageWidget + + Example: `image_widget_kwargs={"cmap": "viridis"}` + """ + + container = CNMFVizContainer( + dataframe=self._dataframe, + start_index=start_index, + temporal_data_option=temporal_data_option, + image_data_options=image_data_options, + temporal_kwargs=temporal_kwargs, + reset_timepoint_on_change=reset_timepoint_on_change, + input_movie_kwargs=input_movie_kwargs, + image_widget_kwargs=image_widget_kwargs, + data_grid_kwargs=data_grid_kwargs, + + ) + + return container From 476c5e89b150a5c3f8d6fcbddb81d5526782cb3d Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 1 Nov 2023 06:18:47 -0400 Subject: [PATCH 31/36] update setup and version --- mesmerize_viz/__init__.py | 2 +- setup.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/mesmerize_viz/__init__.py b/mesmerize_viz/__init__.py index c15d829..e9ebdcb 100644 --- a/mesmerize_viz/__init__.py +++ b/mesmerize_viz/__init__.py @@ -1,2 +1,2 @@ from ._mcorr import MCorrDataFrameVizExtension -from ._cnmf import CNMFDataFrameVizExtension \ No newline at end of file +from ._cnmf import CNMFDataFrameVizExtension diff --git a/setup.py b/setup.py index ab75695..a5e6bda 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,6 @@ install_requires = [ - "mesmerize-core", "fastplotlib[notebook]>=v0.1.0.a14", "ipydatagrid", "tslearn", From d7942d890cb6408c2680f559e613bc331874104e Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 1 Nov 2023 06:19:13 -0400 Subject: [PATCH 32/36] update utils --- mesmerize_viz/_utils.py | 34 ---------------------------------- 1 file changed, 34 deletions(-) diff --git a/mesmerize_viz/_utils.py b/mesmerize_viz/_utils.py index 0390daf..f19581a 100644 --- a/mesmerize_viz/_utils.py +++ b/mesmerize_viz/_utils.py @@ -1,9 +1,6 @@ -from itertools import chain -from functools import wraps from typing import * import numpy as np -from mesmerize_core.arrays._base import LazyArray # to format params dict into yaml-like string @@ -15,37 +12,6 @@ ) if isinstance(d, dict) else str(d) -def validate_data_options(): - def dec(func): - @wraps(func) - def wrapper(self, *args, **kwargs): - if "data_options" in kwargs: - data_options = kwargs["data_options"] - else: - if len(args) > 0: - data_options = args[0] - else: - # assume the extension func will take care of it - # the default data arg is None is nothing is passed - return func(self, *args, **kwargs) - - # flatten - if any([isinstance(d, (list, tuple)) for d in data_options]): - data_options = list(chain.from_iterable(data_options)) - - valid_options = list(self._data_mapping.keys()) - - for d in data_options: - if d not in valid_options: - raise KeyError(f"Invalid data option: \"{d}\", valid options are:" - f"\n{valid_options}") - return func(self, *args, **kwargs) - - return wrapper - - return dec - - class DummyMovie: """Really really hacky""" def __init__(self, image: np.ndarray, shape, ndim, size): From 902dc316d99437c3bf390cdbe08b17c8f4e49fd4 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 1 Nov 2023 06:19:23 -0400 Subject: [PATCH 33/36] update version --- mesmerize_viz/VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mesmerize_viz/VERSION b/mesmerize_viz/VERSION index a69aed0..b7dd9ec 100644 --- a/mesmerize_viz/VERSION +++ b/mesmerize_viz/VERSION @@ -1 +1 @@ -0.1.0.a1 +0.1.0.b1 From 24be97bca943e1c495db8f94305589f49785eb45 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 1 Nov 2023 06:36:34 -0400 Subject: [PATCH 34/36] qt context for cnmf viz --- mesmerize_viz/_cnmf.py | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/mesmerize_viz/_cnmf.py b/mesmerize_viz/_cnmf.py index 5b61be9..2048ac2 100644 --- a/mesmerize_viz/_cnmf.py +++ b/mesmerize_viz/_cnmf.py @@ -12,6 +12,8 @@ import fastplotlib as fpl from fastplotlib.utils import get_cmap from caiman.source_extraction.cnmf import CNMF +from IPython.display import display +from sidecar import Sidecar from mesmerize_core.caiman_extensions.cnmf import cnmf_cache from mesmerize_core import CNMFExtensions @@ -1039,13 +1041,28 @@ def show(self, sidecar: bool = False): """ - temporals = VBox([self._plot_temporal.show(), self._plot_heatmap.show()]) + if self.image_widget.gridplot.canvas.__class__.__name__ == "JupyterWgpuCanvas": + temporals = VBox([self._plot_temporal.show(), self._plot_heatmap.show()]) + plots = HBox([temporals, self._image_widget.widget]) + widget = VBox([self._top_widget, plots, self._tab_contours_eval]) + if sidecar: + with Sidecar(): + return display(widget) + else: + return widget - iw_contour_controls = VBox([self._image_widget.widget, self._tab_contours_eval]) + elif self.image_widget.gridplot.canvas.__class__.__name__ == "QWgpuCanvas": + self.plot_temporal.show() + self.plot_heatmap.show() + self.image_widget.show() - plots = HBox([temporals, iw_contour_controls]) + widget = VBox([self._top_widget, self._tab_contours_eval]) - return VBox([self._top_widget, plots]) + return widget + else: + raise EnvironmentError( + "No available output context. Make sure you're running in jupyterlab or using %gui qt" + ) @pd.api.extensions.register_dataframe_accessor("cnmf") From 9cda51fce20fa0d0e6dca4f2277b84d7f1b6b0a5 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 1 Nov 2023 06:49:06 -0400 Subject: [PATCH 35/36] bug fixes --- mesmerize_viz/_cnmf.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/mesmerize_viz/_cnmf.py b/mesmerize_viz/_cnmf.py index 2048ac2..6775687 100644 --- a/mesmerize_viz/_cnmf.py +++ b/mesmerize_viz/_cnmf.py @@ -654,6 +654,7 @@ def _row_changed(self, *args): # no exceptions, set plots self._set_data(data_arrays) self._set_params_text_area(index) + self.current_row = index def _set_data(self, data_arrays: Dict[str, np.ndarray]): self._contour_graphics.clear() @@ -704,6 +705,10 @@ def _set_data(self, data_arrays: Dict[str, np.ndarray]): reset_indices=self.reset_timepoint_on_change, reset_vmin_vmax=True ) + + for g in self._image_widget.managed_graphics: + g.registered_callbacks.clear() + for subplot in self._image_widget.gridplot: if "contours" in subplot: # delete the contour graphics @@ -772,8 +777,6 @@ def _euclidean(self, source, target, event, new_data): self.set_component_index(ix) - self.component_int_box.value = ix - return None def set_component_index(self, index): From 44a45355fcf24f342cff040326793376231e474d Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 1 Nov 2023 07:01:50 -0400 Subject: [PATCH 36/36] update example nbs --- examples/cnmf.ipynb | 154 ++++++++++++++++++------------------------- examples/mcorr.ipynb | 74 +++++---------------- 2 files changed, 79 insertions(+), 149 deletions(-) diff --git a/examples/cnmf.ipynb b/examples/cnmf.ipynb index e5daca9..5ef4af8 100644 --- a/examples/cnmf.ipynb +++ b/examples/cnmf.ipynb @@ -2,29 +2,12 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "c0e29f1f-a33b-45ab-877a-29291bb5fd01", "metadata": { "tags": [] }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2023-10-14 03:49:11.570098: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.\n", - "2023-10-14 03:49:11.599719: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", - "2023-10-14 03:49:11.599748: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", - "2023-10-14 03:49:11.599774: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", - "2023-10-14 03:49:11.605509: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.\n", - "2023-10-14 03:49:11.606166: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", - "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", - "2023-10-14 03:49:12.496589: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n", - "Possible incompatible version of wgpu:\n", - " Detected 0.11.0, need >=0.10.0, <0.11.0.\n" - ] - } - ], + "outputs": [], "source": [ "from mesmerize_core import *\n", "from mesmerize_viz import *" @@ -32,7 +15,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "326b9f89-7bcd-4b76-be74-9d1d84b96ca2", "metadata": { "tags": [] @@ -48,89 +31,78 @@ }, { "cell_type": "code", - "execution_count": 3, - "id": "e95b0266-5056-48d0-9410-df0d6fdb151e", + "execution_count": null, + "id": "c5ba8e2e-e1ec-447c-8cdb-4e6b35680e11", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# optional\n", + "# %gui qt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "914fb902-1476-4ba8-9b8e-85c338211fba", "metadata": { "tags": [] }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/kushal/venvs/mescore/lib/python3.11/site-packages/ipydatagrid/datagrid.py:460: UserWarning: Index name of 'index' is not round-trippable.\n", - " schema = pd.io.json.build_table_schema(dataframe)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "5081da17f0024189a932a8a1e41b2892", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "RFBOutputContext()" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "52f6f09085584836b3cbd5afb51e7139", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "RFBOutputContext()" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "edf7f52fd7bc4b378fb43c4c33c9a854", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "RFBOutputContext()" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/kushal/repos/fastplotlib/fastplotlib/graphics/_features/_base.py:34: UserWarning: converting float64 array to float32\n", - " warn(f\"converting {array.dtype} array to float32\")\n", - "/home/kushal/repos/mesmerize-viz/mesmerize_viz/_cnmf/_viz_container.py:425: FutureWarning: You are trying to use the following experimental feature, this may change in the future without warning:\n", - "CaimanDataFrameExtensions.get_params_diffs\n", - "This feature is new and the might improve in the future\n", - "\n", - " param_diffs = self._dataframe.caiman.get_params_diffs(\n" - ] - } - ], + "outputs": [], + "source": [ + "viz = df.cnmf.viz(temporal_kwargs={\"add_residuals\": True})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c0427b7d-b14b-4cfe-91af-d555c92e5557", + "metadata": { + "tags": [] + }, + "outputs": [], "source": [ - "container_widget = df.cnmf.viz(data_options=[[\"temporal\"], [\"heatmap-zscore\"], [\"input\", \"rcm\"]], start_index=1)" + "viz.show()" ] }, { "cell_type": "code", - "execution_count": 4, - "id": "d603672a-f43c-42f3-bf9a-de4499e27659", + "execution_count": null, + "id": "510b2710-744a-4e26-a485-e494d22fbd81", "metadata": {}, "outputs": [], "source": [ - "container_widget.show()" + "viz.cnmf_obj.estimates.idx_components" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "517e9e01-0f45-49b0-b1b7-ff0a9b3a03bb", + "metadata": {}, + "outputs": [], + "source": [ + "viz.cnmf_obj.estimates.idx_components_bad" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "78c3bcf5-7147-44cc-a9af-c8c0969b9940", + "metadata": {}, + "outputs": [], + "source": [ + "df.iloc[1].cnmf.get_output().estimates.idx_components, df.iloc[1].cnmf.get_output().estimates.idx_components_bad" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c32a7f90-6121-4e2e-8cd3-399cf0d5acd4", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/examples/mcorr.ipynb b/examples/mcorr.ipynb index 55f9230..c037d85 100644 --- a/examples/mcorr.ipynb +++ b/examples/mcorr.ipynb @@ -10,29 +10,12 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "e4d01194-8f05-4221-821f-3550e7e039f7", "metadata": { "tags": [] }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2023-10-14 03:49:01.457161: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.\n", - "2023-10-14 03:49:01.489296: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", - "2023-10-14 03:49:01.489325: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", - "2023-10-14 03:49:01.489351: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", - "2023-10-14 03:49:01.495237: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.\n", - "2023-10-14 03:49:01.496002: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", - "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", - "2023-10-14 03:49:02.279430: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n", - "Possible incompatible version of wgpu:\n", - " Detected 0.11.0, need >=0.10.0, <0.11.0.\n" - ] - } - ], + "outputs": [], "source": [ "from mesmerize_core import *\n", "from mesmerize_viz import *" @@ -40,7 +23,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "1225545f-9056-435e-bfef-53ff98376c8e", "metadata": { "tags": [] @@ -55,68 +38,43 @@ }, { "cell_type": "code", - "execution_count": 3, - "id": "562fb85c-aedc-4d7e-8810-7846fe74e47f", - "metadata": {}, + "execution_count": null, + "id": "c2580d2b-b577-42a2-ae6b-5a89a90f0678", + "metadata": { + "tags": [] + }, "outputs": [], "source": [ - "from mesmerize_viz._utils import format_params" + "# optional\n", + "%gui qt" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "7b5d4c32-c5e7-42ed-8e39-55fcda337bcc", "metadata": { "tags": [] }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/kushal/venvs/mescore/lib/python3.11/site-packages/ipydatagrid/datagrid.py:460: UserWarning: Index name of 'index' is not round-trippable.\n", - " schema = pd.io.json.build_table_schema(dataframe)\n", - "/home/kushal/repos/mesmerize-viz/mesmerize_viz/_mcorr.py:311: FutureWarning: You are trying to use the following experimental feature, this may change in the future without warning:\n", - "CaimanDataFrameExtensions.get_params_diffs\n", - "This feature is new and the might improve in the future\n", - "\n", - " param_diffs = self._dataframe.caiman.get_params_diffs(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "b5fa0c34225f441581d8fd223e60a894", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "RFBOutputContext()" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ - "container_widget = df.mcorr.viz(data_options=[\"input\", \"mcorr\", \"mean\", \"corr\"])" + "viz = df.mcorr.viz(data_options=[\"input\", \"mcorr\", \"mean\", \"corr\"])" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "id": "c7fac51e-5d91-4204-af25-4fce40ae20d5", "metadata": {}, "outputs": [], "source": [ - "container_widget.show()" + "viz.show()" ] }, { "cell_type": "code", "execution_count": null, - "id": "50355864-53e1-44e4-898a-6499a061c0fd", + "id": "661e3f21-dd49-4d14-a058-7d395f873fd0", "metadata": {}, "outputs": [], "source": []