From c58dbdc4077079b57b2d9a1bfbe29e6b14d27b17 Mon Sep 17 00:00:00 2001 From: John Huddleston Date: Tue, 3 Sep 2024 15:10:00 -0700 Subject: [PATCH] Save optimal projected frequencies for fixed model Saves optimal projected frequencies per timepoint and strain from the optimal earth mover's distance calculation between each timepoint into the same scores data structure that stores the "projected_frequency" values. --- src/popcast/fit.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/popcast/fit.py b/src/popcast/fit.py index fee9c0e..b1b05a5 100644 --- a/src/popcast/fit.py +++ b/src/popcast/fit.py @@ -421,6 +421,11 @@ def _fit(self, coefficients, X, y, use_l1_penalty=True, calculate_optimal_distan # Estimate target values. y_hat = self.predict(X, coefficients) + # Save optimal frequencies by timepoint and strain, if calculating + # optimal distance to the future. + if calculate_optimal_distance: + optimal_frequency_records = [] + # Calculate EMD for each timepoint in the estimated values and sum that # distance across all timepoints. error = 0.0 @@ -471,6 +476,12 @@ def _fit(self, coefficients, X, y, use_l1_penalty=True, calculate_optimal_distan cost=distance_matrix ) + optimal_frequency_records.append(pd.DataFrame({ + "strain": samples_a, + "timepoint": timepoint, + "optimal_projected_frequency": estimated_frequencies, + })) + # Estimate the distance between the model's estimated future and the # observed future populations. model_emd, _, self.model_flow = cv2.EMD( @@ -490,6 +501,9 @@ def _fit(self, coefficients, X, y, use_l1_penalty=True, calculate_optimal_distan else: l1_penalty = 0.0 + if calculate_optimal_distance: + self.optimal_frequencies = pd.concat(optimal_frequency_records) + return error + l1_penalty def _fit_distance(self, coefficients, X, y, use_l1_penalty=True): @@ -861,6 +875,11 @@ def test(model_class, model_kwargs, data, targets, timepoints, coefficients=None # Get the estimated frequencies for test sets to export. test_y_hat = model.predict(test_X) + test_y_hat = test_y_hat.merge( + model.optimal_frequencies, + on=["timepoint", "strain"], + validate="1:1", + ) # Convert timestamps to a serializable format. for df in [test_X, test_y, test_y_hat]: