Skip to content

Commit

Permalink
Merge pull request #2 from blab/export-optimal-projected-frequencies
Browse files Browse the repository at this point in the history
Save optimal projected frequencies for fixed model
  • Loading branch information
huddlej authored Sep 4, 2024
2 parents 1406e14 + c58dbdc commit f64a31c
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions src/popcast/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,11 @@ def _fit(self, coefficients, X, y, use_l1_penalty=True, calculate_optimal_distan
# Estimate target values.
y_hat = self.predict(X, coefficients)

# Save optimal frequencies by timepoint and strain, if calculating
# optimal distance to the future.
if calculate_optimal_distance:
optimal_frequency_records = []

# Calculate EMD for each timepoint in the estimated values and sum that
# distance across all timepoints.
error = 0.0
Expand Down Expand Up @@ -471,6 +476,12 @@ def _fit(self, coefficients, X, y, use_l1_penalty=True, calculate_optimal_distan
cost=distance_matrix
)

optimal_frequency_records.append(pd.DataFrame({
"strain": samples_a,
"timepoint": timepoint,
"optimal_projected_frequency": estimated_frequencies,
}))

# Estimate the distance between the model's estimated future and the
# observed future populations.
model_emd, _, self.model_flow = cv2.EMD(
Expand All @@ -490,6 +501,9 @@ def _fit(self, coefficients, X, y, use_l1_penalty=True, calculate_optimal_distan
else:
l1_penalty = 0.0

if calculate_optimal_distance:
self.optimal_frequencies = pd.concat(optimal_frequency_records)

return error + l1_penalty

def _fit_distance(self, coefficients, X, y, use_l1_penalty=True):
Expand Down Expand Up @@ -861,6 +875,11 @@ def test(model_class, model_kwargs, data, targets, timepoints, coefficients=None

# Get the estimated frequencies for test sets to export.
test_y_hat = model.predict(test_X)
test_y_hat = test_y_hat.merge(
model.optimal_frequencies,
on=["timepoint", "strain"],
validate="1:1",
)

# Convert timestamps to a serializable format.
for df in [test_X, test_y, test_y_hat]:
Expand Down

0 comments on commit f64a31c

Please # to comment.