diff --git a/src/conformist/base_cop.py b/src/conformist/base_cop.py index 64734bf..e133ba2 100644 --- a/src/conformist/base_cop.py +++ b/src/conformist/base_cop.py @@ -15,6 +15,11 @@ class BaseCoP(OutputDir): + FIGURE_FONTSIZE = 12 + FIGURE_WIDTH = 12 + FIGURE_HEIGHT = 8 + plt.rcParams.update({'font.size': FIGURE_FONTSIZE}) + def __init__(self, prediction_dataset: PredictionDataset, alpha=0.1): self.prediction_dataset = prediction_dataset self.alpha = alpha @@ -171,7 +176,10 @@ def predict(self, df = pd.DataFrame(stats, index=[0]) df.T.to_csv(f'{self.output_dir}/summary.csv', header=False) - return formatted_predictions, vr + if validate: + return formatted_predictions, vr + else: + return formatted_predictions def prediction_set_to_text(self, prediction_set, display_classes=None): class_names = self.class_names @@ -190,7 +198,8 @@ def prediction_sets_to_text(self, prediction_sets, display_classes=None): for prediction_set in prediction_sets] def upset_plot(self, predictions_sets, output_dir, color="black"): - plt.figure() + plt.figure(figsize=(self.FIGURE_WIDTH, + self.FIGURE_HEIGHT)) class_names = self.class_names