From fb5bd083f0ab6f2572e989a1c7c0b8f04219904a Mon Sep 17 00:00:00 2001 From: Gleb <41195376+zhukgleb@users.noreply.github.com> Date: Sat, 9 Dec 2023 02:43:22 +0500 Subject: [PATCH] feat: add correlation mode to template_correlation (#1114) * feat: add correlation mode to template_correlation * Apply suggestions from code review * Clean up trailing whitespace --------- Co-authored-by: Ricky O'Steen <39831871+rosteen@users.noreply.github.com> --- specutils/analysis/correlation.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/specutils/analysis/correlation.py b/specutils/analysis/correlation.py index 7bb6e6ee7..434fa9825 100644 --- a/specutils/analysis/correlation.py +++ b/specutils/analysis/correlation.py @@ -4,6 +4,7 @@ from astropy.nddata import StdDevUncertainty from astropy.units import Quantity from scipy.signal.windows import tukey +from scipy.signal import correlate from ..manipulation import LinearInterpolatedResampler from .. import Spectrum1D @@ -14,7 +15,7 @@ def template_correlate(observed_spectrum, template_spectrum, lag_units=_KMS, - apodization_window=0.5, resample=True): + apodization_window=0.5, resample=True, method="direct"): """ Compute cross-correlation of the observed and template spectra. @@ -49,6 +50,11 @@ def template_correlate(observed_spectrum, template_spectrum, lag_units=_KMS, ``template_logwl_resample(spectrum, template, delta_log_wavelength=.1)``. If False, *no* resampling is performed (and the user is responsible for a sensible resampling). + method: str + If you choose "FFT", the correlation will be done through the use + of convolution and will be calculated faster (for small spectral + resolutions it is often correct), otherwise the correlation is determined + directly from sums (the "direct" method in `~scipy.signal.correlate`). Returns ------- @@ -84,9 +90,14 @@ def template_correlate(observed_spectrum, template_spectrum, lag_units=_KMS, normalization = 1. # Correlate - corr = np.correlate(observed_log_spectrum.flux.value, - (template_log_spectrum.flux.value * normalization), - mode='full') + if method.lower() == "fft": + corr = correlate(observed_log_spectrum.flux.value, + (template_log_spectrum.flux.value * normalization), + method="fft") + else: + corr = correlate(observed_log_spectrum.flux.value, + (template_log_spectrum.flux.value * normalization), + method="direct") # Compute lag # wave_l is the wavelength array equally spaced in log space.