From 10512ef9a4e4bfa4de3a3f2e2a8f7bfe454d85a4 Mon Sep 17 00:00:00 2001
From: Mariya Lysenkova Wiklander
Date: Thu, 21 Nov 2024 17:27:57 +0100
Subject: [PATCH] Add mean set size to heatmap

---
 src/conformist/prediction_dataset.py | 97 ++++++++++++++++++++++++----
 1 file changed, 84 insertions(+), 13 deletions(-)

diff --git a/src/conformist/prediction_dataset.py b/src/conformist/prediction_dataset.py
index 02158d0..a212579 100644
--- a/src/conformist/prediction_dataset.py
+++ b/src/conformist/prediction_dataset.py
@@ -4,6 +4,7 @@
 import seaborn as sns
 from matplotlib.patches import Patch
 from matplotlib.font_manager import FontProperties
+import matplotlib.gridspec as gridspec
 from upsetplot import plot
 
 from .output_dir import OutputDir
@@ -328,51 +329,117 @@ def visualize_class_counts_by_dataset(self,
         plt.savefig(f'{self.output_dir}/class_counts_by_dataset.png',
                     bbox_inches='tight')
 
-    def visualize_prediction_heatmap(self):
-        plt.figure(figsize=(self.FIGURE_WIDTH, 8))
+    def visualize_prediction_heatmap(self, min_softmax_threshold=0.5):
+        fig = plt.figure(figsize=(self.FIGURE_WIDTH, 10))
+
+        # Create two subplots, we will create two separate heatmaps
+        # Define a GridSpec with unequal column widths
+        # Column 1 is nx as wide as column 2
+        width_ratio = 8
+        gs = gridspec.GridSpec(1, 2, width_ratios=[width_ratio, 1])
+        ax1 = fig.add_subplot(gs[0, 0])
+        ax2 = fig.add_subplot(gs[0, 1])
+
+        SET_SIZE_COL_NAME = "mean set size"
 
         group_by_col = self.MELTED_KNOWN_CLASS_COL
         df = self.melt()
-
         grouped_df = df.groupby(group_by_col)
         pred_col_names = self.class_names()
+        set_size_col_names = [SET_SIZE_COL_NAME]
+
+        psets = self._get_prediction_sets_by_softmax_threshold(
+            min_softmax_threshold)
+        # Index psets by df.id
+        psets.set_index(df[self.ID_COL], inplace=True)
+        # Set new column set_size that contains the number of True values in other cols
+        psets['set_size'] = psets.sum(axis=1)
+        # Add melted known class to psets
+        psets = psets.merge(df[[self.ID_COL, group_by_col]],
+                            left_index=True,
+                            right_on=self.ID_COL)
+
+        # Create new df with columns class and mean_set_size
+        psets_df = psets.groupby(
+            self.MELTED_KNOWN_CLASS_COL)['set_size'].mean()
 
         mean_smx = []
+        mean_set_size = []
 
         for name, group in grouped_df:
-            name = self.translate_class_name(name)
             mean_smx_row = [name]
+            mean_set_size_row = [name]
             for col in pred_col_names:
                 mean_smx_row.append(group[col].mean())
+            mean_set_size_row.append(psets_df[name])
             mean_smx.append(mean_smx_row)
+            mean_set_size.append(mean_set_size_row)
 
-        col_names = ['true_class_name'] + self.class_names(translate=True)
+        col_names_1 = ['true_class_name'] + pred_col_names
+        col_names_2 = ['true_class_name'] + set_size_col_names
 
-        mean_smx_df = pd.DataFrame(mean_smx, columns=col_names)
+        mean_smx_df = pd.DataFrame(mean_smx, columns=col_names_1)
         mean_smx_df.set_index('true_class_name', inplace=True)
 
+        mean_set_size_df = pd.DataFrame(mean_set_size, columns=col_names_2)
+        mean_set_size_df.set_index('true_class_name', inplace=True)
+
         # Sort the rows and columns
         mean_smx_df.sort_index(axis=0, inplace=True)  # Sort rows
         mean_smx_df.sort_index(axis=1, inplace=True)  # Sort columns
 
+        mean_set_size_df.sort_index(axis=0, inplace=True)  # Sort rows
+        mean_set_size_df.sort_index(axis=1, inplace=True)  # Sort columns
+
         # Remove any columns where all the rows are 0
         mean_smx_df = mean_smx_df.loc[:, (mean_smx_df != 0).any(axis=0)]
 
         hm = sns.heatmap(mean_smx_df,
+                         ax=ax1,
                          cmap="coolwarm",
                          annot=True,
-                         fmt='.2f')
+                         fmt='.2f',
+                         cbar=False)
 
         labelpad = 20
 
         plt.setp(hm.get_yticklabels(), rotation=0)
 
-        hm.set_xlabel('MEAN PROBABILITY SCORE',
+        hm.set_xlabel('MEAN SOFTMAX SCORE',
                       weight='bold', labelpad=labelpad)
 
         hm.set_ylabel('TRUE CLASS',
                       weight='bold', labelpad=labelpad)
 
+        # Create second heatmap for mean set size
+        hm2 = sns.heatmap(mean_set_size_df,
+                          ax=ax2,
+                          cmap=sns.light_palette("purple", as_cmap=True),
+                          annot=True,
+                          fmt='.2f',
+                          cbar=False)
+
+        # Rotate x labels
+        plt.setp(hm2.get_xticklabels(), rotation=90)
+
+        # Set y label
+        hm2.set_ylabel(f"MEAN PREDICTION SET SIZE, SOFTMAX > {min_softmax_threshold}",
+                       weight='bold', labelpad=labelpad)
+
+        # Position y label to the right of heatmap
+        hm2.yaxis.set_label_position("right")
+
+        # Remove y ticks
+        hm2.set_yticks([])
+
+        # Remove x label
+        hm2.set_xlabel('')
+
+        # Remove x ticks
+        hm2.set_xticks([])
+
+        plt.tight_layout(w_pad=0.1)  # Control padding
+
         # Save the plot to a file
         plt.savefig(f'{self.output_dir}/prediction_heatmap.png',
                     bbox_inches='tight')
@@ -507,10 +574,7 @@ def visualize_prediction_stripplot(self,
         plt.savefig(f'{self.output_dir}/prediction_stripplot.png',
                     bbox_inches='tight')
 
-    def visualize_model_sets(self, min_softmax_threshold=0.5, color="black"):
-        plt.figure()
-        plt.figure(figsize=(self.FIGURE_WIDTH, 8))
-
+    def _get_prediction_sets_by_softmax_threshold(self, min_softmax_threshold=0.5):
         df = self.melt()
         cols = [col for col in df.columns if col in self.class_names()]
 
@@ -523,7 +587,14 @@ def visualize_model_sets(self, min_softmax_threshold=0.5, color="black"):
                 new_row[col] = (row[col] >= min_softmax_threshold)
             rows.append(new_row)
 
-        upset_data = pd.concat([new_df, pd.DataFrame(rows)], ignore_index=True)
+        return pd.concat([new_df, pd.DataFrame(rows)], ignore_index=True)
+
+
+    def visualize_model_sets(self, min_softmax_threshold=0.5, color="black"):
+        plt.figure()
+        plt.figure(figsize=(self.FIGURE_WIDTH, 8))
+        upset_data = self._get_prediction_sets_by_softmax_threshold(
+            min_softmax_threshold)
 
         # Set a multi-index
         upset_data.set_index(upset_data.columns.tolist(), inplace=True)