From e7a622141e904960f05984ebacd18d1b34af1d6b Mon Sep 17 00:00:00 2001 From: Mariya Lysenkova Wiklander Date: Tue, 26 Nov 2024 11:03:52 +0100 Subject: [PATCH] Add functions to count classes appearing in prediction sets --- src/conformist/validation_run.py | 19 +++++++++++++++++++ src/conformist/validation_trial.py | 18 ++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/src/conformist/validation_run.py b/src/conformist/validation_run.py index de1f47d..9fd29cb 100644 --- a/src/conformist/validation_run.py +++ b/src/conformist/validation_run.py @@ -117,6 +117,25 @@ def mean_fnrs_by_class(self, sets, class_names): averages[key] = sum(sizes) / len(sizes) return averages + def prediction_counts_by_class(self, class_names, coocurring_only=False): + counts = {} + for i in range(len(self.prediction_sets)): + labels = self.prediction_sets[i] + + # Get corresponding values from class_names + pset_class_names = [class_names[i] for i, label in enumerate(labels) if label == 1] + + do_count = True + if coocurring_only: + do_count = len(pset_class_names) > 1 + + if do_count: + for class_name in pset_class_names: + class_counts = counts.get(class_name, 0) + class_counts += 1 + counts[class_name] = class_counts + return counts + def run_reports(self, base_output_dir): mean_set_sizes = self.mean_set_sizes_by_class(self.class_names) mean_fnrs = self.mean_fnrs_by_class(self.prediction_sets, self.class_names) diff --git a/src/conformist/validation_trial.py b/src/conformist/validation_trial.py index 01316a6..6bca1fc 100644 --- a/src/conformist/validation_trial.py +++ b/src/conformist/validation_trial.py @@ -115,6 +115,24 @@ def mean_fnrs_by_class(self, class_names): means[class_name] = statistics.mean(d_means) return means + def mean_prediction_counts_by_class(self, class_names, coocurring_only=False): + prediction_count_dicts = [] + for run in self.runs: + prediction_count_dicts.append(run.prediction_counts_by_class(class_names, coocurring_only)) + + means = {} + for class_name in class_names: + d_means = [] + for d in prediction_count_dicts: + if class_name in d: + d_means.append(d[class_name]) + if len(d_means) > 0: + means[class_name] = statistics.mean(d_means) + + # Sort by count descending + means = dict(sorted(means.items(), key=lambda item: item[1], reverse=True)) + return means + def run_reports(self, base_output_dir): self.create_output_dir(base_output_dir) self.visualize_empirical_fnr()