Skip to content

Commit

Permalink
Add functions to count classes appearing in prediction sets
Browse files Browse the repository at this point in the history
  • Loading branch information
mariya committed Nov 26, 2024
1 parent 47b8f6a commit e7a6221
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 0 deletions.
19 changes: 19 additions & 0 deletions src/conformist/validation_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,25 @@ def mean_fnrs_by_class(self, sets, class_names):
averages[key] = sum(sizes) / len(sizes)
return averages

def prediction_counts_by_class(self, class_names, coocurring_only=False):
counts = {}
for i in range(len(self.prediction_sets)):
labels = self.prediction_sets[i]

# Get corresponding values from class_names
pset_class_names = [class_names[i] for i, label in enumerate(labels) if label == 1]

do_count = True
if coocurring_only:
do_count = len(pset_class_names) > 1

if do_count:
for class_name in pset_class_names:
class_counts = counts.get(class_name, 0)
class_counts += 1
counts[class_name] = class_counts
return counts

def run_reports(self, base_output_dir):
mean_set_sizes = self.mean_set_sizes_by_class(self.class_names)
mean_fnrs = self.mean_fnrs_by_class(self.prediction_sets, self.class_names)
Expand Down
18 changes: 18 additions & 0 deletions src/conformist/validation_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,24 @@ def mean_fnrs_by_class(self, class_names):
means[class_name] = statistics.mean(d_means)
return means

def mean_prediction_counts_by_class(self, class_names, coocurring_only=False):
prediction_count_dicts = []
for run in self.runs:
prediction_count_dicts.append(run.prediction_counts_by_class(class_names, coocurring_only))

means = {}
for class_name in class_names:
d_means = []
for d in prediction_count_dicts:
if class_name in d:
d_means.append(d[class_name])
if len(d_means) > 0:
means[class_name] = statistics.mean(d_means)

# Sort by count descending
means = dict(sorted(means.items(), key=lambda item: item[1], reverse=True))
return means

def run_reports(self, base_output_dir):
self.create_output_dir(base_output_dir)
self.visualize_empirical_fnr()
Expand Down

0 comments on commit e7a6221

Please # to comment.