Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Minor fixes to ParticleBeam.plot_distribution #330

Merged
merged 9 commits into from
Mar 4, 2025
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
- Python 3.9 is no longer supported. This does not immediately break existing code, but might cause it to break in the future. (see #325) (@jank324)
- The covariance properties of the different beam classes were renamed from names like `cor_x` and `sigma_xpx` to consistent names like `cov_xpx` (see #331) (@jank324)
- The signature of the `transfer_map` method of all element subclasses was extended by a non-optional `species` argument (see #276) (@cr-xu, @jank324, @Hespe)
- `ParticleBeam.plot_distribution` allows for Seaborn-style passing of `axs` and returns the latter as well. In line with that change for the purpose of overlaying distributions, the `contour` argument of `ParticleBeam.plot_2d_distribution` was replaced by a `style` argument. (see #330) (@jank324)

### 🚀 Features

Expand Down
60 changes: 29 additions & 31 deletions cheetah/particles/particle_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from cheetah.particles.species import Species
from cheetah.utils import (
elementwise_linspace,
format_axis_as_percentage,
format_axis_with_prefixed_unit,
unbiased_weighted_covariance,
unbiased_weighted_std,
Expand Down Expand Up @@ -1100,21 +1099,17 @@ def plot_1d_distribution(
# Handle units
if dimension in ("x", "y", "tau"):
base_unit = "m"
elif dimension in ("px", "py", "p"):
base_unit = "%"

if dimension in ("x", "y", "tau"):
format_axis_with_prefixed_unit(ax.xaxis, base_unit, centers)
elif dimension in ("px", "py", "p"):
format_axis_as_percentage(ax.xaxis)

return ax

def plot_2d_distribution(
self,
x_dimension: Literal["x", "px", "y", "py", "tau", "p"],
y_dimension: Literal["x", "px", "y", "py", "tau", "p"],
contour: bool = False,
style: Literal["histogram", "contour"] = "histogram",
bins: int = 100,
bin_ranges: Optional[Tuple[Tuple[float]]] = None,
histogram_smoothing: float = 0.0,
Expand All @@ -1130,7 +1125,7 @@ def plot_2d_distribution(
`('x', 'px', 'y', 'py', 'tau', 'p')`.
:param y_dimension: Name of the y dimension to plot. Should be one of
`('x', 'px', 'y', 'py', 'tau', 'p')`.
:param contour: If `True`, overlay contour lines on the 2D histogram plot.
:param style: Style of the plot. Should be one of `('histogram', 'contour')`.
:param bins: Number of bins to use for the histogram in both dimensions.
:param bin_ranges: Ranges of the bins to use for the histogram in each
dimension.
Expand Down Expand Up @@ -1158,14 +1153,14 @@ def plot_2d_distribution(
# Post-process and plot
smoothed_histogram = gaussian_filter(histogram, histogram_smoothing)
clipped_histogram = np.where(smoothed_histogram > 1, smoothed_histogram, np.nan)
ax.pcolormesh(
x_edges,
y_edges,
clipped_histogram.T / smoothed_histogram.max(),
**{"cmap": "rainbow"} | (pcolormesh_kws or {}),
)

if contour:
if style == "histogram":
ax.pcolormesh(
x_edges,
y_edges,
clipped_histogram.T / smoothed_histogram.max(),
**{"cmap": "rainbow"} | (pcolormesh_kws or {}),
)
elif style == "contour":
contour_histogram = gaussian_filter(histogram, contour_smoothing)

ax.contour(
Expand All @@ -1181,23 +1176,15 @@ def plot_2d_distribution(
# Handle units
if x_dimension in ("x", "y", "tau"):
x_base_unit = "m"
elif x_dimension in ("px", "py", "p"):
x_base_unit = "%"

if y_dimension in ("x", "y", "tau"):
y_base_unit = "m"
elif y_dimension in ("px", "py", "p"):
y_base_unit = "%"

if x_dimension in ("x", "y", "tau"):
format_axis_with_prefixed_unit(ax.xaxis, x_base_unit, x_centers)
elif x_dimension in ("px", "py", "p"):
format_axis_as_percentage(ax.xaxis)

if y_dimension in ("x", "y", "tau"):
format_axis_with_prefixed_unit(ax.yaxis, y_base_unit, y_centers)
elif y_dimension in ("px", "py", "p"):
format_axis_as_percentage(ax.yaxis)

return ax

Expand All @@ -1210,7 +1197,8 @@ def plot_distribution(
] = None,
plot_1d_kws: Optional[dict] = None,
plot_2d_kws: Optional[dict] = None,
) -> plt.Figure:
axs: Optional[List[plt.Axes]] = None,
) -> Tuple[plt.Figure, np.ndarray]:
"""
Plot of coordinates projected into 2D planes.

Expand All @@ -1227,13 +1215,23 @@ def plot_distribution(
`ParticleBeam.plot_1d_distribution` for plotting 1D histograms.
:param plot_2d_kws: Additional keyword arguments to be passed to
`ParticleBeam.plot_2d_distribution` for plotting 2D histograms.
:return: Matplotlib figure object.
:param axs: List of Matplotlib axes objects to use for plotting. If set to
`None`, a new figure is created. Must have the shape `(len(dimensions),
len(dimensions))`.
:return: Matplotlib figure and axes objects with the plot.
"""
fig, axs = plt.subplots(
len(dimensions),
len(dimensions),
figsize=(2 * len(dimensions), 2 * len(dimensions)),
)
if axs is None:
fig, axs = plt.subplots(
len(dimensions),
len(dimensions),
figsize=(2 * len(dimensions), 2 * len(dimensions)),
)
else:
fig = axs[0, 0].figure
assert axs.shape == (len(dimensions), len(dimensions)), (
"If `axs` is provided, it must have the shape "
f"`({len(dimensions)}, {len(dimensions)})`."
)

# Determine bin ranges for all plots in the grid at once
full_tensor = (
Expand Down Expand Up @@ -1348,7 +1346,7 @@ def plot_distribution(
axs[i, i].set_yticks([])
axs[i, i].set_ylabel(None)

return fig
return fig, axs

def plot_point_cloud(
self, scatter_kws: Optional[dict] = None, ax: Optional[plt.Axes] = None
Expand Down
6 changes: 4 additions & 2 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
import torch

import cheetah
Expand Down Expand Up @@ -140,12 +141,13 @@ def test_plotting_with_gradients():
segment.plot_twiss(incoming=beam)


def test_plot_6d_particle_beam_distribution():
@pytest.mark.parametrize("style", ["histogram", "contour"])
def test_plot_6d_particle_beam_distribution(style):
"""Test that the 6D `ParticleBeam` distribution plot does not raise an exception."""
beam = cheetah.ParticleBeam.from_astra("tests/resources/ACHIP_EA1_2021.1351.001")

# Run the plotting to see if it raises an exception
_ = beam.plot_distribution(bin_ranges="unit_same", plot_2d_kws={"contour": True})
_ = beam.plot_distribution(bin_ranges="unit_same", plot_2d_kws={"style": style})


def test_plot_particle_beam_point_cloud():
Expand Down
Loading