Skip to content

Commit

Permalink
feat: clean up contrast limits
Browse files Browse the repository at this point in the history
  • Loading branch information
seankmartin committed Oct 17, 2024
1 parent 356ca03 commit 87e25e5
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 205 deletions.
247 changes: 44 additions & 203 deletions cryoet_data_portal_neuroglancer/precompute/contrast_limits.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import decimate, find_peaks
from scipy.signal import find_peaks
from sklearn.mixture import GaussianMixture

from cryoet_data_portal_neuroglancer.utils import ParameterOptimizer
Expand All @@ -32,11 +32,10 @@ def euclidean_distance(x: tuple[float, float], y: tuple[float, float]) -> float:
return np.sqrt(((x[0] - y[0]) ** 2) + ((x[1] - y[1]) ** 2))


# TODO fix this to work with dask data, see the mesh changes for reference
def _restrict_volume_around_central_z_slice(
volume: "np.ndarray",
central_z_slice: Optional[int] = None,
z_radius: Optional[int] = None,
z_radius: Optional[int] = 5,
) -> "np.ndarray":
"""Restrict a 3D volume to a region around a central z-slice.
Expand All @@ -49,7 +48,9 @@ def _restrict_volume_around_central_z_slice(
By default None, in which case the central z-slice is the middle slice.
z_radius: int, optional.
The number of z-slices to include above and below the central z-slice.
By default None, in which case it is auto computed.
By default 5,
If it is None, it is auto computed - but this can be problematic for large volumes.
Hence the default is a fixed value.
Returns
-------
Expand Down Expand Up @@ -233,6 +234,7 @@ def _objective_function(self, params):
return self.compute_contrast_limit(params["low_percentile"], params["high_percentile"])

def _define_parameter_space(self, parameter_optimizer: ParameterOptimizer):
"""NOTE: the range here is very small, for real-tuning, it should be larger."""
parameter_optimizer.space_creator(
{
"low_percentile": {"type": "randint", "args": [1, 2]},
Expand Down Expand Up @@ -296,6 +298,9 @@ def compute_contrast_limit(
covariance_type = "full"
sample_data = self.volume.flatten()

# Find the best number of components - using BIC
# BIC is a criterion for model selection among a finite set of models
# The model with the lowest BIC is preferred.
bics = np.zeros(shape=(max_components, 2))
for n in range(1, max_components + 1):
n = int(n)
Expand All @@ -315,7 +320,9 @@ def compute_contrast_limit(
min_bic_index = np.argmin(bics[:, 0])
best_n = int(bics[min_bic_index, 1])

variance_multi_dict = {
# With less components, we need to be more conservative
# Hence the standard_deviation multiplier is higher
std_dev_multi_dict = {
1: (2.0, 0.5),
2: (2.2, 0.65),
3: (3.0, 0.8),
Expand All @@ -330,24 +337,19 @@ def compute_contrast_limit(
init_params="k-means++",
)
self.gmm_estimator.fit(sample_data.reshape(-1, 1))
# Get the stats for the gaussian which sits in the middle
means = self.gmm_estimator.means_.flatten()
covariances = self.gmm_estimator.covariances_

# The shape depends on `covariance_type`::
# (n_components,) if 'spherical',
# (n_features, n_features) if 'tied',
# (n_components, n_features) if 'diag',
# (n_components, n_features, n_features) if 'full'
variances = covariances.flatten()
# Extract the means and variances
means = self.gmm_estimator.means_.flatten()
covariances = self.gmm_estimator.covariances_ # (n_components, n_features, n_features)
variances = covariances.flatten() # n_features is 1, so this is n_components

# Pick the GMM component which is closest to the mean of the volume
volume_mean = np.mean(sample_data)
closest_mean_index = np.argmin(np.abs(means - volume_mean))
mean_to_use = means[closest_mean_index]
std_to_use = np.sqrt(variances[closest_mean_index])

low_variance_mult, high_variance_mult = variance_multi_dict[best_n]
low_variance_mult, high_variance_mult = std_dev_multi_dict[best_n]

low_limit, high_limit = (
mean_to_use - low_variance_mult * std_to_use,
Expand All @@ -365,30 +367,14 @@ def _objective_function(self, params):
)

def _define_parameter_space(self, parameter_optimizer):
"""NOTE: the range here is very small, for real-tuning, it should be larger."""
parameter_optimizer.space_creator(
{
"low_variance_mult": {"type": "uniform", "args": [2.2, 2.21]},
"high_variance_mult": {"type": "uniform", "args": [0.6, 0.61]},
},
)

def plot(self, output_filename: Optional[str | Path] = None) -> None:
    """Plot the GMM clusters."""
    fig, ax = plt.subplots()

    # One marker per fitted component: x = component mean,
    # y = component standard deviation (sqrt of its variance).
    component_means = self.gmm_estimator.means_.flatten()
    component_std_devs = [np.sqrt(v) for v in self.gmm_estimator.covariances_.flatten()]
    ax.plot(component_means, component_std_devs, "o")

    ax.set_xlabel("Mean")
    ax.set_ylabel("Standard Deviation")

    # Save to disk when a filename is given, otherwise show interactively.
    if output_filename:
        fig.savefig(output_filename)
    else:
        plt.show()
    plt.close(fig)


class CDFContrastLimitCalculator(ContrastLimitCalculator):

Expand All @@ -402,23 +388,21 @@ def __init__(self, volume: Optional["np.ndarray"] = None):
"""
super().__init__(volume)
self.cdf = None
self.limits = None
self.second_derivative = None

def automatic_parameter_estimation(self):
def automatic_parameter_estimation(self, gradient_threshold=0.3):
_, _, gradient, _ = self._caculate_cdf(n_bins=512)

largest_peak = np.argmax(gradient)
peak_gradient = gradient[largest_peak]
# Find the start gradient percentage
# Before the gradient climbs above 20% of the peak gradient
# Find the median values of the gradient
start_of_rising = np.where(gradient > 0.3 * peak_gradient)[0][0]
start_of_rising = np.where(gradient > gradient_threshold * peak_gradient)[0][0]
mean_before_rising = np.mean(gradient[:start_of_rising])
start_gradient_threshold = mean_before_rising / peak_gradient

# Find the end gradient percentage
end_of_flattening = np.where(gradient[start_of_rising:] < 0.3 * peak_gradient)[0][0]
end_of_flattening = np.where(gradient[start_of_rising:] < gradient_threshold * peak_gradient)[0][0]
mean_after_rising = np.median(gradient[start_of_rising + end_of_flattening :])
end_gradient_threshold = mean_after_rising / peak_gradient

Expand All @@ -442,13 +426,17 @@ def _caculate_cdf(self, n_bins):
@compute_with_timer
def compute_contrast_limit(
self,
start_gradient: float = 0.08,
end_gradient: float = 0.08,
start_multiplier: float = 1.0,
end_multiplier: float = 0.4,
gradient_threshold: float = 0.3,
) -> tuple[float, float]:
"""Calculate the contrast limits using the Cumulative Distribution Function.
Parameters
----------
gradient_threshold: float, optional.
The threshold multiplier against the peak gradient.
This is used to estimate the start and end of the contrast limits.
By default 0.3.
Returns
-------
tuple[float, float]
Expand All @@ -461,187 +449,44 @@ def compute_contrast_limit(
largest_peak = np.argmax(gradient)
peak_gradient_value = gradient[largest_peak]

start_gradient, end_gradient = self.automatic_parameter_estimation()
# Find the start and end gradient percentages
start_gradient, end_gradient = self.automatic_parameter_estimation(gradient_threshold)

# Find where the gradient starts rising and starts flattening after the peak
start_of_rising = np.where(gradient > start_gradient * peak_gradient_value)[0][0]
# Find the first point after the largest peak where the gradient is less than 0.1 * peak_gradient_value
end_of_flattening = np.where(gradient[largest_peak:] < end_gradient * peak_gradient_value)[0][0]
end_of_flattening += largest_peak

start_value = bin_edges[start_of_rising]
end_value = bin_edges[end_of_flattening]
middle_value = bin_edges[largest_peak]
start_to_middle = middle_value - start_value
middle_to_end = end_value - middle_value
start_limit = middle_value - start_multiplier * start_to_middle
end_limit = middle_value + end_multiplier * middle_to_end
start_limit = bin_edges[start_of_rising]
end_limit = bin_edges[end_of_flattening]

self.cdf = [x, cdf]
try:
self.limits = (start_limit.compute(), end_limit.compute())
limits = (start_limit.compute(), end_limit.compute())
except AttributeError:
self.limits = (start_limit, end_limit)
limits = (start_limit, end_limit)

# Ensure that the limits are within the range of the volume
self.limits = (
max(self.limits[0], np.min(self.volume)),
min(self.limits[1], np.max(self.volume)),
return (
float(max(limits[0], np.min(self.volume.flatten()))),
float(min(limits[1], np.max(self.volume.flatten()))),
)
self.first_derivative = gradient
self.second_derivative = np.gradient(gradient)

return self.limits

def _objective_function(self, params):
return self.compute_contrast_limit(
params["start_gradient"],
params["end_gradient"],
params["start_multiplier"],
params["end_multiplier"],
)
return self.compute_contrast_limit(params["gradient_treshold"])

def _define_parameter_space(self, parameter_optimizer):
parameter_optimizer.space_creator(
{
"start_gradient": {"type": "uniform", "args": [0.01, 0.3]},
"end_gradient": {"type": "uniform", "args": [0.01, 0.3]},
"start_multiplier": {"type": "uniform", "args": [1.0, 1.0001]},
"end_multiplier": {"type": "uniform", "args": [1.0, 1.00001]},
"gradient_treshold": {"type": "uniform", "args": [0.05, 0.6]},
},
)

def plot(self, output_filename: Optional[str | Path] = None, real_limits: Optional[list] = None) -> None:
    """Plot the CDF and the calculated limits."""
    fig, ax = plt.subplots()

    x_values, cdf_values = self.cdf
    ax.plot(x_values, cdf_values)

    # Computed limits in red; reference limits (if supplied) in blue.
    for computed_limit in self.limits:
        ax.axvline(computed_limit, color="r")
    if real_limits:
        ax.axvline(real_limits[0], color="b")
        ax.axvline(real_limits[1], color="b")

    # Derivatives scaled x100 so they are visible next to the CDF.
    ax.plot(x_values, self.first_derivative * 100, "y")
    ax.plot(x_values, self.second_derivative * 100, "g")

    # Save to disk when a filename is given, otherwise show interactively.
    if output_filename:
        fig.savefig(output_filename)
    else:
        plt.show()
    plt.close(fig)


class SignalDecimationContrastLimitCalculator(ContrastLimitCalculator):
    # Contrast-limit estimator that finds the rising slope of the intensity
    # CDF after low-pass filtering + downsampling it with scipy's decimate.

    def __init__(self, volume: Optional["np.ndarray"] = None):
        """Initialize the contrast limit calculator.

        Parameters
        ----------
        volume: np.ndarray or None, optional.
            Input volume for calculating contrast limits.
        """
        super().__init__(volume)
        # Populated by compute_contrast_limit: [x, cdf] pair, (low, high)
        # limits, and the decimated signal (never assigned in this class —
        # presumably reserved for plotting/debugging; confirm before relying on it).
        self.cdf = None
        self.limits = None
        self.decimation = None

    @compute_with_timer
    def compute_contrast_limit(
        self,
        downsample_factor: int = 5,
        sample_factor: float = 0.10,
        threshold_factor: float = 0.01,
    ) -> tuple[float, float]:
        """Calculate the contrast limits using decimation.

        Returns
        -------
        tuple[float, float]
            The calculated contrast limits.
        """
        # Calculate the histogram of the volume
        n_bins = 512
        min_value = np.min(self.volume.flatten())
        max_value = np.max(self.volume.flatten())
        hist, _ = np.histogram(self.volume.flatten(), bins=n_bins, range=[min_value, max_value])

        # Calculate the CDF of the histogram
        cdf = np.cumsum(hist) / np.sum(hist)
        x = np.linspace(min_value, max_value, n_bins)

        # Downsampling the CDF (decimate applies an anti-aliasing filter
        # before downsampling, smoothing small fluctuations)
        y_decimated = decimate(cdf, downsample_factor)
        x_decimated = np.linspace(np.min(x), np.max(x), len(y_decimated))

        # Calculate the absolute differences between consecutive points in the decimated CDF
        diff_decimated = np.abs(np.diff(y_decimated))

        # Compute threshold and lower_change threshold
        sample_size = int(sample_factor * len(diff_decimated))

        initial_flat = np.mean(cdf[:sample_size])  # Average of first points (assumed flat region)
        final_flat = np.mean(cdf[-sample_size:])  # Average of last points (assumed flat region)
        midpoint = (initial_flat + final_flat) / 2
        curve_threshold = threshold_factor * midpoint

        # Detect start and end of slope
        start_idx_decimated = np.argmax(diff_decimated > curve_threshold)  # First large change
        end_idx_decimated = (
            np.argmax(diff_decimated[start_idx_decimated + 1 :] < curve_threshold) + start_idx_decimated
        )  # first small change

        # Map back the indices to original values
        self.cdf = [x, cdf]
        # NOTE(review): np.argmax never returns -1 (it returns 0 when no
        # element is True), so the `!= -1` sentinel below can only trigger
        # when start_idx_decimated is 0 and no small change is found —
        # verify this is the intended "no end found" condition.
        self.limits = (
            (x_decimated[start_idx_decimated], x_decimated[end_idx_decimated])
            if end_idx_decimated != -1
            else (None, None)
        )

        return self.limits

    def _objective_function(self, params):
        # Adapter so a ParameterOptimizer can score one hyper-parameter set.
        return self.compute_contrast_limit(
            params["downsample_factor"],
            params["sample_factor"],
            params["threshold_factor"],
        )

    def _define_parameter_space(self, parameter_optimizer):
        # Search ranges for the three decimation hyper-parameters.
        parameter_optimizer.space_creator(
            {
                "downsample_factor": {"type": "randint", "args": [3, 7]},
                "sample_factor": {"type": "uniform", "args": [0.01, 0.1]},
                "threshold_factor": {"type": "uniform", "args": [0.005, 0.2]},
            },
        )

    def plot(self, output_filename: Optional[str | Path] = None, real_limits: Optional[list] = None) -> None:
        """Plot the CDF and the calculated limits."""
        fig, ax = plt.subplots()

        # CDF curve, computed limits in red, reference limits in blue.
        ax.plot(self.cdf[0], self.cdf[1])
        ax.axvline(self.limits[0], color="r")
        ax.axvline(self.limits[1], color="r")

        if real_limits:
            ax.axvline(real_limits[0], color="b")
            ax.axvline(real_limits[1], color="b")

        # Save to disk when a filename is given, otherwise show interactively.
        if output_filename:
            fig.savefig(output_filename)
        else:
            plt.show()
        plt.close(fig)


def combined_contrast_limit_plot(
cdf: list[list[float], list[float]],
real_limits: tuple[float, float],
limits_dict: dict[str, tuple[float, float]],
output_filename: Optional[str | Path] = None,
output_filename: str | Path,
) -> None:
"""Plot the CDF and the calculated limits."""
fig, ax = plt.subplots()
Expand All @@ -653,7 +498,7 @@ def combined_contrast_limit_plot(
from matplotlib.lines import Line2D

custom_lines = [Line2D([0], [0], color="b", lw=4)]
colors_dict = {"gmm": "g", "cdf": "y", "decimation": "r"}
colors_dict = {"gmm": "g", "cdf": "y"}
min_x = real_limits[0]
max_x = real_limits[1]
for key, limits in limits_dict.items():
Expand All @@ -674,8 +519,4 @@ def combined_contrast_limit_plot(
legend.append(key + " Limits")
ax.legend(custom_lines, legend)

if output_filename:
fig.savefig(output_filename)
else:
plt.show()
plt.close(fig)
fig.savefig(output_filename)
Loading

0 comments on commit 87e25e5

Please sign in to comment.