Skip to content

Commit 3ca9098

Browse files
authored
Visualisation: Allow specifying Agent shapes in agent_portrayal (#2214)
This PR allows specifying an `"shape"` in the `agent_portrayal` dictionary used by matplotlib component of the Solara visualisation. In short, it allows you represent an Agent in any [matplotlib marker](https://matplotlib.org/stable/api/markers_api.html), by adding a "shape" key-value pair to the `agent_portrayal` dictionary. This is especially useful when you're using the default shape drawer for grid or continuous space. For example: ```Python def agent_portrayal(cell): return { "color": "blue" "size": 5, "shape": "h" # marker is a hexagon! } ```
1 parent 3cf1b76 commit 3ca9098

File tree

3 files changed

+64
-25
lines changed

3 files changed

+64
-25
lines changed

docs/tutorials/visualization_tutorial.ipynb

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,8 @@
142142
"source": [
143143
"#### Changing the agents\n",
144144
"\n",
145-
"In the visualization above, all we could see is the agents moving around -- but not how much money they had, or anything else of interest. Let's change it so that agents who are broke (wealth 0) are drawn in red, smaller. (TODO: currently, we can't predict the drawing order of the circles, so a broke agent may be overshadowed by a wealthy agent. We should fix this by doing a hollow circle instead)\n",
145+
"In the visualization above, all we could see is the agents moving around -- but not how much money they had, or anything else of interest. Let's change it so that agents who are broke (wealth 0) are drawn in red, smaller. (TODO: Currently, we can't predict the drawing order of the circles, so a broke agent may be overshadowed by a wealthy agent. We should fix this by doing a hollow circle instead)\n",
146+
"In addition to size and color, an agent's shape can also be customized when using the default drawer. The allowed values for shapes can be found [here](https://matplotlib.org/stable/api/markers_api.html).\n",
146147
"\n",
147148
"To do this, we go back to our `agent_portrayal` code and add some code to change the portrayal based on the agent properties and launch the server again."
148149
]

mesa/visualization/components/matplotlib.py

Lines changed: 60 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from collections import defaultdict
2+
13
import networkx as nx
24
import solara
35
from matplotlib.figure import Figure
@@ -23,12 +25,44 @@ def SpaceMatplotlib(model, agent_portrayal, dependencies: list[any] | None = Non
2325
solara.FigureMatplotlib(space_fig, format="png", dependencies=dependencies)
2426

2527

28+
# matplotlib scatter does not allow for multiple shapes in one call
29+
def _split_and_scatter(portray_data, space_ax):
30+
grouped_data = defaultdict(lambda: {"x": [], "y": [], "s": [], "c": []})
31+
32+
# Extract data from the dictionary
33+
x = portray_data["x"]
34+
y = portray_data["y"]
35+
s = portray_data["s"]
36+
c = portray_data["c"]
37+
m = portray_data["m"]
38+
39+
if not (len(x) == len(y) == len(s) == len(c) == len(m)):
40+
raise ValueError(
41+
"Length mismatch in portrayal data lists: "
42+
f"x: {len(x)}, y: {len(y)}, size: {len(s)}, "
43+
f"color: {len(c)}, marker: {len(m)}"
44+
)
45+
46+
# Group the data by marker
47+
for i in range(len(x)):
48+
marker = m[i]
49+
grouped_data[marker]["x"].append(x[i])
50+
grouped_data[marker]["y"].append(y[i])
51+
grouped_data[marker]["s"].append(s[i])
52+
grouped_data[marker]["c"].append(c[i])
53+
54+
# Plot each group with the same marker
55+
for marker, data in grouped_data.items():
56+
space_ax.scatter(data["x"], data["y"], s=data["s"], c=data["c"], marker=marker)
57+
58+
2659
def _draw_grid(space, space_ax, agent_portrayal):
2760
def portray(g):
2861
x = []
2962
y = []
3063
s = [] # size
3164
c = [] # color
65+
m = [] # shape
3266
for i in range(g.width):
3367
for j in range(g.height):
3468
content = g._grid[i][j]
@@ -41,23 +75,23 @@ def portray(g):
4175
data = agent_portrayal(agent)
4276
x.append(i)
4377
y.append(j)
44-
if "size" in data:
45-
s.append(data["size"])
46-
if "color" in data:
47-
c.append(data["color"])
48-
out = {"x": x, "y": y}
49-
# This is the default value for the marker size, which auto-scales
50-
# according to the grid area.
51-
out["s"] = (180 / max(g.width, g.height)) ** 2
52-
if len(s) > 0:
53-
out["s"] = s
54-
if len(c) > 0:
55-
out["c"] = c
78+
79+
# This is the default value for the marker size, which auto-scales
80+
# according to the grid area.
81+
default_size = (180 / max(g.width, g.height)) ** 2
82+
# establishing a default prevents misalignment if some agents are not given size, color, etc.
83+
size = data.get("size", default_size)
84+
s.append(size)
85+
color = data.get("color", "tab:blue")
86+
c.append(color)
87+
mark = data.get("shape", "o")
88+
m.append(mark)
89+
out = {"x": x, "y": y, "s": s, "c": c, "m": m}
5690
return out
5791

5892
space_ax.set_xlim(-1, space.width)
5993
space_ax.set_ylim(-1, space.height)
60-
space_ax.scatter(**portray(space))
94+
_split_and_scatter(portray(space), space_ax)
6195

6296

6397
def _draw_network_grid(space, space_ax, agent_portrayal):
@@ -77,20 +111,23 @@ def portray(space):
77111
y = []
78112
s = [] # size
79113
c = [] # color
114+
m = [] # shape
80115
for agent in space._agent_to_index:
81116
data = agent_portrayal(agent)
82117
_x, _y = agent.pos
83118
x.append(_x)
84119
y.append(_y)
85-
if "size" in data:
86-
s.append(data["size"])
87-
if "color" in data:
88-
c.append(data["color"])
89-
out = {"x": x, "y": y}
90-
if len(s) > 0:
91-
out["s"] = s
92-
if len(c) > 0:
93-
out["c"] = c
120+
121+
# This is matplotlib's default marker size
122+
default_size = 20
123+
# establishing a default prevents misalignment if some agents are not given size, color, etc.
124+
size = data.get("size", default_size)
125+
s.append(size)
126+
color = data.get("color", "tab:blue")
127+
c.append(color)
128+
mark = data.get("shape", "o")
129+
m.append(mark)
130+
out = {"x": x, "y": y, "s": s, "c": c, "m": m}
94131
return out
95132

96133
# Determine border style based on space.torus
@@ -110,7 +147,7 @@ def portray(space):
110147
space_ax.set_ylim(space.y_min - y_padding, space.y_max + y_padding)
111148

112149
# Portray and scatter the agents in the space
113-
space_ax.scatter(**portray(space))
150+
_split_and_scatter(portray(space), space_ax)
114151

115152

116153
@solara.component

mesa/visualization/solara_viz.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,8 @@ def SolaraViz(
103103
model_params: Parameters for initializing the model
104104
measures: List of callables or data attributes to plot
105105
name: Name for display
106-
agent_portrayal: Options for rendering agents (dictionary)
106+
agent_portrayal: Options for rendering agents (dictionary);
107+
Default drawer supports custom `"size"`, `"color"`, and `"shape"`.
107108
space_drawer: Method to render the agent space for
108109
the model; default implementation is the `SpaceMatplotlib` component;
109110
simulations with no space to visualize should

0 commit comments

Comments
 (0)