diff --git a/kavian/eda/plot.py b/kavian/eda/plot.py index 5ddd12a..65343ff 100644 --- a/kavian/eda/plot.py +++ b/kavian/eda/plot.py @@ -205,21 +205,35 @@ def heatmap(dataframe, palette='kavian', subset=None): numerical = dataframe.select_dtypes(NUM) corr = numerical.corr() - font_size = 18 - len(numerical.columns) - annot_kws = {'size': font_size, 'fontweight': 'bold', 'fontstyle': 'italic'} + font_size = 18 - len(numerical.columns) + font_size = 8 if font_size < 8 else font_size cbar_kws = {'pad': 0.01} if palette == 'kavian': palette = sns.diverging_palette(18, 240, s=80, l=50, n=19, center="dark") - sns.heatmap(corr, ax=ax, cmap=palette, annot=True, annot_kws=annot_kws, - cbar_kws=cbar_kws, fmt='.2f', linecolor='black', linewidth=0.5, square=True) + too_many_cols = len(numerical.columns) >= 12 + + # Too many values for annotations to be legible + if too_many_cols: + sns.heatmap(corr, ax=ax, cmap=palette, cbar_kws=cbar_kws, + fmt='.2f', linecolor='black', linewidth=0.5, square=True) + + else: + annot_kws = {'size': font_size, 'fontweight': 'bold', 'fontstyle': 'italic'} + sns.heatmap(corr, ax=ax, cmap=palette, annot=True, annot_kws=annot_kws, + cbar_kws=cbar_kws, fmt='.2f', linecolor='black', linewidth=0.5, square=True) + + xticklabels = '' if too_many_cols else ax.get_xticklabels() ax.tick_params(rotation=20) - ax.set_xticklabels(ax.get_xticklabels(), fontweight='bold', fontstyle='italic') - ax.set_yticklabels(ax.get_yticklabels(), fontweight='bold', fontstyle='italic') + ax.set_xticklabels(xticklabels, fontweight='bold', fontstyle='italic', fontsize=font_size) + ax.set_yticklabels(ax.get_yticklabels(), fontweight='bold', fontstyle='italic', fontsize=font_size) ax.set_title(f'EDA Heatmap', fontdict={'fontsize': 24, 'fontfamily': 'serif'}) plt.tight_layout() plt.show() + + +