Skip to content

Commit

Permalink
Add mean set size to heatmap
Browse files Browse the repository at this point in the history
  • Loading branch information
mariya committed Nov 21, 2024
1 parent 1c9a269 commit 10512ef
Showing 1 changed file with 84 additions and 13 deletions.
97 changes: 84 additions & 13 deletions src/conformist/prediction_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import seaborn as sns
from matplotlib.patches import Patch
from matplotlib.font_manager import FontProperties
import matplotlib.gridspec as gridspec
from upsetplot import plot

from .output_dir import OutputDir
Expand Down Expand Up @@ -328,51 +329,117 @@ def visualize_class_counts_by_dataset(self,
plt.savefig(f'{self.output_dir}/class_counts_by_dataset.png',
bbox_inches='tight')

def visualize_prediction_heatmap(self):
plt.figure(figsize=(self.FIGURE_WIDTH, 8))
def visualize_prediction_heatmap(self, min_softmax_threshold=0.5):
fig = plt.figure(figsize=(self.FIGURE_WIDTH, 10))

# Create two subplots, we will create two separate heatmaps
# Define a GridSpec with unequal column widths
# Column 1 is nx as wide as column 2
width_ratio = 8
gs = gridspec.GridSpec(1, 2, width_ratios=[width_ratio, 1])
ax1 = fig.add_subplot(gs[0, 0])
ax2 = fig.add_subplot(gs[0, 1])

SET_SIZE_COL_NAME = "mean set size"

group_by_col = self.MELTED_KNOWN_CLASS_COL
df = self.melt()

grouped_df = df.groupby(group_by_col)
pred_col_names = self.class_names()
set_size_col_names = [SET_SIZE_COL_NAME]

psets = self._get_prediction_sets_by_softmax_threshold(
min_softmax_threshold)
# Index psets by df.id
psets.set_index(df[self.ID_COL], inplace=True)
# Set new column set_size that contains the number of True values in other cols
psets['set_size'] = psets.sum(axis=1)
# Add melted known class to psets
psets = psets.merge(df[[self.ID_COL, group_by_col]],
left_index=True,
right_on=self.ID_COL)

# Create new df with columns class and mean_set_size
psets_df = psets.groupby(
self.MELTED_KNOWN_CLASS_COL)['set_size'].mean()

mean_smx = []
mean_set_size = []

for name, group in grouped_df:
name = self.translate_class_name(name)
mean_smx_row = [name]
mean_set_size_row = [name]

for col in pred_col_names:
mean_smx_row.append(group[col].mean())
mean_set_size_row.append(psets_df[name])

mean_smx.append(mean_smx_row)
mean_set_size.append(mean_set_size_row)

col_names = ['true_class_name'] + self.class_names(translate=True)
col_names_1 = ['true_class_name'] + pred_col_names
col_names_2 = ['true_class_name'] + set_size_col_names

mean_smx_df = pd.DataFrame(mean_smx, columns=col_names)
mean_smx_df = pd.DataFrame(mean_smx, columns=col_names_1)
mean_smx_df.set_index('true_class_name', inplace=True)

mean_set_size_df = pd.DataFrame(mean_set_size, columns=col_names_2)
mean_set_size_df.set_index('true_class_name', inplace=True)

# Sort the rows and columns
mean_smx_df.sort_index(axis=0, inplace=True) # Sort rows
mean_smx_df.sort_index(axis=1, inplace=True) # Sort columns

mean_set_size_df.sort_index(axis=0, inplace=True) # Sort rows
mean_set_size_df.sort_index(axis=1, inplace=True) # Sort columns

# Remove any columns where all the rows are 0
mean_smx_df = mean_smx_df.loc[:, (mean_smx_df != 0).any(axis=0)]

hm = sns.heatmap(mean_smx_df,
ax=ax1,
cmap="coolwarm",
annot=True,
fmt='.2f')
fmt='.2f',
cbar=False)

labelpad = 20
plt.setp(hm.get_yticklabels(), rotation=0)

hm.set_xlabel('MEAN PROBABILITY SCORE',
hm.set_xlabel('MEAN SOFTMAX SCORE',
weight='bold', labelpad=labelpad)
hm.set_ylabel('TRUE CLASS',
weight='bold', labelpad=labelpad)

# Create second heatmap for mean set size
hm2 = sns.heatmap(mean_set_size_df,
ax=ax2,
cmap=sns.light_palette("purple", as_cmap=True),
annot=True,
fmt='.2f',
cbar=False)

# Rotate x labels
plt.setp(hm2.get_xticklabels(), rotation=90)

# Set y label
hm2.set_ylabel(f"MEAN PREDICTION SET SIZE, SOFTMAX > {min_softmax_threshold}",
weight='bold', labelpad=labelpad)

# Position y label to the right of heatmap
hm2.yaxis.set_label_position("right")

# Remove y ticks
hm2.set_yticks([])

# Remove x label
hm2.set_xlabel('')

# Remove x ticks
hm2.set_xticks([])

plt.tight_layout(w_pad=0.1) # Control padding

# Save the plot to a file
plt.savefig(f'{self.output_dir}/prediction_heatmap.png', bbox_inches='tight')

Expand Down Expand Up @@ -507,10 +574,7 @@ def visualize_prediction_stripplot(self,
plt.savefig(f'{self.output_dir}/prediction_stripplot.png',
bbox_inches='tight')

def visualize_model_sets(self, min_softmax_threshold=0.5, color="black"):
plt.figure()
plt.figure(figsize=(self.FIGURE_WIDTH, 8))

def _get_prediction_sets_by_softmax_threshold(self, min_softmax_threshold=0.5):
df = self.melt()
cols = [col for col in df.columns if col in self.class_names()]

Expand All @@ -523,7 +587,14 @@ def visualize_model_sets(self, min_softmax_threshold=0.5, color="black"):
new_row[col] = (row[col] >= min_softmax_threshold)
rows.append(new_row)

upset_data = pd.concat([new_df, pd.DataFrame(rows)], ignore_index=True)
return pd.concat([new_df, pd.DataFrame(rows)], ignore_index=True)


def visualize_model_sets(self, min_softmax_threshold=0.5, color="black"):
plt.figure()
plt.figure(figsize=(self.FIGURE_WIDTH, 8))
upset_data = self._get_prediction_sets_by_softmax_threshold(
min_softmax_threshold)
# Set a multi-index
upset_data.set_index(upset_data.columns.tolist(), inplace=True)

Expand Down

0 comments on commit 10512ef

Please # to comment.