diff --git a/src/conformist/validation_run.py b/src/conformist/validation_run.py index 60e9664..99216ff 100644 --- a/src/conformist/validation_run.py +++ b/src/conformist/validation_run.py @@ -118,10 +118,14 @@ def mean_fnrs_by_class(self, sets, class_names): return averages def run_reports(self, base_output_dir): + mean_set_sizes = self.mean_set_sizes_by_class(self.class_names) + mean_fnrs = self.mean_fnrs_by_class(self.prediction_sets, self.class_names) + mean_model_fnrs = self.mean_fnrs_by_class(self.model_predictions, self.class_names) + pr = PerformanceReport(base_output_dir) - pr.report_class_statistics(self.mean_set_sizes_by_class(self.class_names), - self.mean_fnrs_by_class(self.prediction_sets, self.class_names), - self.mean_fnrs_by_class(self.model_predictions, self.class_names)) + pr.report_class_statistics(mean_set_sizes, + mean_fnrs, + mean_model_fnrs) np.seterr(all='raise') self.create_output_dir(base_output_dir) @@ -140,3 +144,10 @@ def run_reports(self, base_output_dir): df.T.to_csv(f'{self.output_dir}/summary.csv', header=False) print(f'Reports saved to {self.output_dir}') + + stats_dict = { + 'mean_set_sizes': mean_set_sizes, + 'mean_fnrs': mean_fnrs, + 'mean_model_fnrs': mean_model_fnrs + } + return stats_dict