diff --git a/cryoet_data_portal_neuroglancer/precompute/contrast_limits.py b/cryoet_data_portal_neuroglancer/precompute/contrast_limits.py index ab479a3..176c407 100644 --- a/cryoet_data_portal_neuroglancer/precompute/contrast_limits.py +++ b/cryoet_data_portal_neuroglancer/precompute/contrast_limits.py @@ -1,10 +1,13 @@ """Methods for computing contrast limits for Neuroglancer image layers.""" -from abc import abstractmethod +from pathlib import Path from typing import Optional +import matplotlib.pyplot as plt import numpy as np from scipy.signal import find_peaks +from sklearn.cluster import KMeans +from sklearn.mixture import GaussianMixture def _restrict_volume_around_central_z_slice( @@ -36,8 +39,7 @@ def _restrict_volume_around_central_z_slice( if z_radius is None: lowest_points = find_peaks(-standard_deviation_per_z_slice, prominence=0.1)[0] if len(lowest_points) < 2: - # TODO create fallback instead - raise ValueError("Not enough low points found") + raise ValueError("Not enough low points found to auto compute z-radius.") for value in lowest_points: if value < central_z_slice: z_min = value @@ -108,7 +110,6 @@ def trim_volume_around_central_zslice( z_radius, ) - @abstractmethod def contrast_limits_from_percentiles( self, low_percentile: float = 1.0, @@ -155,3 +156,97 @@ def contrast_limits_from_mean( width = multipler * rms_value return mean_value - width, mean_value + width + + +class GMMContrastLimitCalculator(ContrastLimitCalculator): + + def __init__(self, volume: Optional["np.ndarray"] = None, num_components: int = 3): + """Initialize the contrast limit calculator. + + Parameters + ---------- + volume: np.ndarray or None, optional. + Input volume for calculating contrast limits. + num_components: int, optional. + The number of components to use for GMM. + By default 3. + """ + super().__init__(volume) + self.num_components = num_components + # cov_type in ["spherical", "diag", "tied", "full"] + self.gmm_estimator = GaussianMixture( + n_components=num_components, + covariance_type="full", + max_iter=200, + random_state=0, + ) + + def contrast_limits_from_gmm(self) -> tuple[float, float]: + """Calculate the contrast limits using Gaussian Mixture Model. + + Returns + ------- + tuple[float, float] + The calculated contrast limits. + """ + self.gmm_estimator.fit(self.volume.reshape(-1, 1)) + + # Get the stats for the gaussian which sits in the middle + means = self.gmm_estimator.means_.flatten() + covariances = self.gmm_estimator.covariances_.flatten() + + return means[1] - 2 * covariances[1], means[1] + 2 * covariances[1] + + +class KMeansContrastLimitCalculator(ContrastLimitCalculator): + + def __init__(self, volume: Optional["np.ndarray"] = None, num_clusters: int = 3): + """Initialize the contrast limit calculator. + + Parameters + ---------- + volume: np.ndarray or None, optional. + Input volume for calculating contrast limits. + num_clusters: int, optional. + The number of clusters to use for KMeans. + By default 3. + """ + super().__init__(volume) + self.num_clusters = num_clusters + self.kmeans_estimator = KMeans(n_clusters=num_clusters, random_state=0) + + def plot_kmeans_clusters(self, output_filename: Optional[str | Path] = None) -> None: + """Plot the KMeans clusters.""" + fig, ax = plt.subplots() + + ax.hist(self.volume.flatten(), bins=100, alpha=0.5) + ax.hist(self.kmeans_estimator.cluster_centers_, bins=100, alpha=0.5) + if output_filename: + fig.savefig(output_filename) + else: + plt.show() + plt.close(fig) + + def contrast_limits_from_kmeans(self) -> tuple[float, float]: + """Calculate the contrast limits using KMeans clustering. + + Parameters + ---------- + num_clusters: int, optional. + The number of clusters to use for KMeans. + By default 3. + + Returns + ------- + tuple[float, float] + The calculated contrast limits. + """ + self.kmeans_estimator.fit(self.volume.reshape(-1, 1)) + + cluster_centers = self.kmeans_estimator.cluster_centers_ + cluster_centers.sort() + + return cluster_centers[0], cluster_centers[-1] + + +# Other possibility is to take the derivative of the histogram and find the peaks