diff --git a/feat/plotting.py b/feat/plotting.py
index fa7c736c..53e7773b 100644
--- a/feat/plotting.py
+++ b/feat/plotting.py
@@ -329,7 +329,7 @@ def draw_vectorfield(
     return ax
 
 
-def draw_muscles(currx, curry, au=None, ax=None, *args, **kwargs):
+def draw_muscles(currx, curry, au=None, ax=None, cmap="Blues", *args, **kwargs):
     """Draw Muscles
 
     Args:
@@ -756,7 +756,7 @@ def draw_muscles(currx, curry, au=None, ax=None, *args, **kwargs):
                 del kwargs[muscle]
     for muscle in todraw.keys():
         if todraw[muscle] == "heatmap":
-            muscles[muscle].set_color(get_heat(muscle, au, facet))
+            muscles[muscle].set_color(get_heat(muscle, au, facet, cmap))
         else:
             muscles[muscle].set_color(todraw[muscle])
         ax.add_patch(muscles[muscle], *args, **kwargs)
@@ -805,7 +805,7 @@ def draw_muscles(currx, curry, au=None, ax=None, *args, **kwargs):
     return ax
 
 
-def get_heat(muscle, au, log):
+def get_heat(muscle, au, log, cmap="Blues"):
     """Function to create heatmap from au vector
 
     Args:
@@ -817,7 +817,7 @@ def get_heat(muscle, au, log):
     Returns:
         color of muscle according to its au value
     """
-    q = sns.color_palette("Blues", 151)
+    q = sns.color_palette(cmap, 151)
     unit = 0
     aus = {
         "masseter_l": 15,
@@ -882,6 +882,7 @@ def plot_face(
     muscles=None,
     ax=None,
     feature_range=False,
+    cmap="Blues",
     color="k",
     linewidth=1,
     linestyle="-",
@@ -937,7 +938,7 @@ def plot_face(
             au = minmax_scale(au, feature_range=(0, 100 * muscle_scaler))
         else:
             au = muscle_scaler.transform(np.array(au).reshape(-1, 1)).squeeze()
-        ax = draw_muscles(currx, curry, ax=ax, au=au, **muscles)
+        ax = draw_muscles(currx, curry, cmap, ax=ax, au=au, **muscles)
 
     if gaze is not None and len((gaze)) != 4:
         warnings.warn(