From c723f84c17b9ea279f78f28586277f61b2e07a8c Mon Sep 17 00:00:00 2001 From: giadarol Date: Fri, 2 Feb 2024 13:44:01 +0100 Subject: [PATCH 01/11] Better for thick bends --- xplt/line.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/xplt/line.py b/xplt/line.py index 616cc80..c9d06f4 100644 --- a/xplt/line.py +++ b/xplt/line.py @@ -19,6 +19,8 @@ from .util import defaults, get, defaults_for from .properties import Property, DataProperty +import xtrack as xt + def iter_elements(line): """Iterate over elements in line @@ -309,6 +311,7 @@ def update(self, survey, line=None, autoscale=False): NAME = get(survey, "name") BEND = get(survey, "angle") + # beam line ############ self.artist_beamline.set_data(X, Y) @@ -332,7 +335,7 @@ def update(self, survey, line=None, autoscale=False): legend_entries = [] for i, (x, y, rt, name, arc) in enumerate(zip(X, Y, R, NAME, BEND)): drift_length = get(survey, "drift_length", None) - if drift_length is not None and drift_length[i] > 0: + if drift_length is not None and drift_length[i] > 0 and isinstance(line[name], xt.Drift): continue # skip drift spaces if self.ignore is not None: if np.any([re.match(pattern, name) is not None for pattern in self.ignore]): @@ -382,6 +385,11 @@ def update(self, survey, line=None, autoscale=False): legend_entries.append(box_style.get("label")) if length > 0 and arc: + + if line[name].isthick: + x = 0.5 * (x + X[i + 1]) + y = 0.5 * (y + Y[i + 1]) + # bending elements as wedge rho = length / arc box = mpl.patches.Wedge( From 9f9a896874d46ae60171d0bd1ccbcfc62df6b254 Mon Sep 17 00:00:00 2001 From: giadarol Date: Fri, 2 Feb 2024 15:38:07 +0100 Subject: [PATCH 02/11] Colors --- xplt/line.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/xplt/line.py b/xplt/line.py index c9d06f4..b8fef77 100644 --- a/xplt/line.py +++ b/xplt/line.py @@ -351,6 +351,13 @@ def update(self, survey, line=None, autoscale=False): order = get(survey, "order", {i: order})[i] length = get(element, "length", None) length = get(survey, "length", {i: length})[i] + + # Patch order for thick elements + if name != '_end_point': + etype_name = line[name].__class__.__name__ + if etype_name in ORDER_NAMED_ELEMENTS: + order = ORDER_NAMED_ELEMENTS[etype_name] + if length is not None: length = length * scale @@ -523,6 +530,13 @@ def _get_config(config, name, **default): return default +ORDER_NAMED_ELEMENTS = { + 'Bend': 0, + 'Quadrupole': 1, + 'Sextupole': 2, + 'Octupole': 3, +} + ## Restrict star imports to local namespace __all__ = [ name From 2cdc24968aa99cef6268b432c4051d5d585e0076 Mon Sep 17 00:00:00 2001 From: giadarol Date: Fri, 2 Feb 2024 16:37:37 +0100 Subject: [PATCH 03/11] Center the bends --- xplt/line.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/xplt/line.py b/xplt/line.py index b8fef77..6e0061c 100644 --- a/xplt/line.py +++ b/xplt/line.py @@ -393,12 +393,25 @@ def update(self, survey, line=None, autoscale=False): if length > 0 and arc: + rho = length / arc + if line[name].isthick: - x = 0.5 * (x + X[i + 1]) - y = 0.5 * (y + Y[i + 1]) + x_mid = 0.5 * (x + X[i + 1]) + y_mid = 0.5 * (y + Y[i + 1]) + + dr = np.array([X[i + 1] - x, Y[i + 1] - y, 0]) + dn = np.cross(dr, [0, 0, 1]) + dn /= np.linalg.norm(dn) + d = np.linalg.norm(dr)/2 + sin_theta = np.abs(d/rho) + dh = d * sin_theta + p_center = np.array([x_mid, y_mid, 0]) - np.sign(arc) * dh * dn + x = p_center[0] + y = p_center[1] + + # bending elements as wedge - rho = length / arc box = mpl.patches.Wedge( **defaults_for( mpl.patches.Wedge, From 9a67ecb5f69f70f78cb94d521370cfc3c051f8a8 Mon Sep 17 00:00:00 2001 From: giadarol Date: Fri, 2 Feb 2024 17:33:16 +0100 Subject: [PATCH 04/11] Support labels as list --- xplt/line.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/xplt/line.py b/xplt/line.py index 6e0061c..45a7e83 100644 --- a/xplt/line.py +++ b/xplt/line.py @@ -249,6 +249,9 @@ def __init__( self.ignore = [ignore] if isinstance(ignore, str) else ignore self.element_width = element_width + if isinstance(self.labels, (list, tuple, np.ndarray)): + self.labels = '|'.join(self.labels) + # Create plot self.ax.set( xlabel=self.label_for(self.projection[0]), ylabel=self.label_for(self.projection[1]) @@ -396,9 +399,9 @@ def update(self, survey, line=None, autoscale=False): rho = length / arc if line[name].isthick: + # Find the center of the arc x_mid = 0.5 * (x + X[i + 1]) y_mid = 0.5 * (y + Y[i + 1]) - dr = np.array([X[i + 1] - x, Y[i + 1] - y, 0]) dn = np.cross(dr, [0, 0, 1]) dn /= np.linalg.norm(dn) @@ -409,8 +412,6 @@ def update(self, survey, line=None, autoscale=False): x = p_center[0] y = p_center[1] - - # bending elements as wedge box = mpl.patches.Wedge( **defaults_for( @@ -432,6 +433,10 @@ def update(self, survey, line=None, autoscale=False): ) else: + if line[name].isthick: + x = 0.5 * (x + X[i + 1]) + y = 0.5 * (y + Y[i + 1]) + # other elements as rect box = mpl.patches.Rectangle( **defaults_for( From 64bf0d882f2c53b1d0c83b9c24b1b21c46eb7474 Mon Sep 17 00:00:00 2001 From: giadarol Date: Fri, 2 Feb 2024 17:58:26 +0100 Subject: [PATCH 05/11] Check for exact string when labels are from list --- xplt/line.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xplt/line.py b/xplt/line.py index 45a7e83..8268ffd 100644 --- a/xplt/line.py +++ b/xplt/line.py @@ -250,7 +250,7 @@ def __init__( self.element_width = element_width if isinstance(self.labels, (list, tuple, np.ndarray)): - self.labels = '|'.join(self.labels) + self.labels = '|'.join(['^' + ll + '$' for ll in self.labels]) # Create plot self.ax.set( From 9416f2435dbd1245252d13bf31ac8263f6bd4bcf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 2 Feb 2024 17:01:18 +0000 Subject: [PATCH 06/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xplt/line.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/xplt/line.py b/xplt/line.py index 8268ffd..0ec48f6 100644 --- a/xplt/line.py +++ b/xplt/line.py @@ -250,7 +250,7 @@ def __init__( self.element_width = element_width if isinstance(self.labels, (list, tuple, np.ndarray)): - self.labels = '|'.join(['^' + ll + '$' for ll in self.labels]) + self.labels = "|".join(["^" + ll + "$" for ll in self.labels]) # Create plot self.ax.set( @@ -314,7 +314,6 @@ def update(self, survey, line=None, autoscale=False): NAME = get(survey, "name") BEND = get(survey, "angle") - # beam line ############ self.artist_beamline.set_data(X, Y) @@ -338,7 +337,11 @@ def update(self, survey, line=None, autoscale=False): legend_entries = [] for i, (x, y, rt, name, arc) in enumerate(zip(X, Y, R, NAME, BEND)): drift_length = get(survey, "drift_length", None) - if drift_length is not None and drift_length[i] > 0 and isinstance(line[name], xt.Drift): + if ( + drift_length is not None + and drift_length[i] > 0 + and isinstance(line[name], xt.Drift) + ): continue # skip drift spaces if self.ignore is not None: if np.any([re.match(pattern, name) is not None for pattern in self.ignore]): @@ -356,7 +359,7 @@ def update(self, survey, line=None, autoscale=False): length = get(survey, "length", {i: length})[i] # Patch order for thick elements - if name != '_end_point': + if name != "_end_point": etype_name = line[name].__class__.__name__ if etype_name in ORDER_NAMED_ELEMENTS: order = ORDER_NAMED_ELEMENTS[etype_name] @@ -405,8 +408,8 @@ def update(self, survey, line=None, autoscale=False): dr = np.array([X[i + 1] - x, Y[i + 1] - y, 0]) dn = np.cross(dr, [0, 0, 1]) dn /= np.linalg.norm(dn) - d = np.linalg.norm(dr)/2 - sin_theta = np.abs(d/rho) + d = np.linalg.norm(dr) / 2 + sin_theta = np.abs(d / rho) dh = d * sin_theta p_center = np.array([x_mid, y_mid, 0]) - np.sign(arc) * dh * dn x = p_center[0] @@ -549,10 +552,10 @@ def _get_config(config, name, **default): ORDER_NAMED_ELEMENTS = { - 'Bend': 0, - 'Quadrupole': 1, - 'Sextupole': 2, - 'Octupole': 3, + "Bend": 0, + "Quadrupole": 1, + "Sextupole": 2, + "Octupole": 3, } ## Restrict star imports to local namespace From cf82aff93b0b542dd6d9269790f1630e8d7cea8a Mon Sep 17 00:00:00 2001 From: giadarol Date: Sat, 3 Feb 2024 13:21:37 +0100 Subject: [PATCH 07/11] No rcparam alteration --- xplt/hooks.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/xplt/hooks.py b/xplt/hooks.py index 1161e26..7b019c1 100644 --- a/xplt/hooks.py +++ b/xplt/hooks.py @@ -40,17 +40,17 @@ def register_matplotlib_options(): cmap_r.name = cmap.name + "_r" mpl.cm.register_cmap(cmap=cmap_r) - # set rcParams - mpl.rcParams.update( - { - "figure.constrained_layout.use": True, - "legend.fontsize": "x-small", - "legend.title_fontsize": "small", - "grid.color": "#DDD", - "axes.prop_cycle": mpl.cycler(color=petroff_colors), - # 'image.cmap': cmap_petroff_gradient, - } - ) + # # set rcParams + # mpl.rcParams.update( + # { + # "figure.constrained_layout.use": True, + # "legend.fontsize": "x-small", + # "legend.title_fontsize": "small", + # "grid.color": "#DDD", + # "axes.prop_cycle": mpl.cycler(color=petroff_colors), + # # 'image.cmap': cmap_petroff_gradient, + # } + # ) def register_pint_options(): From 411670f7b64bb9f1a6b50933aeef457c4181f062 Mon Sep 17 00:00:00 2001 From: giadarol Date: Mon, 5 Feb 2024 17:50:13 +0100 Subject: [PATCH 08/11] Avoid import of xtrack --- xplt/line.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/xplt/line.py b/xplt/line.py index 8268ffd..ec4c5dd 100644 --- a/xplt/line.py +++ b/xplt/line.py @@ -19,9 +19,6 @@ from .util import defaults, get, defaults_for from .properties import Property, DataProperty -import xtrack as xt - - def iter_elements(line): """Iterate over elements in line @@ -338,7 +335,8 @@ def update(self, survey, line=None, autoscale=False): legend_entries = [] for i, (x, y, rt, name, arc) in enumerate(zip(X, Y, R, NAME, BEND)): drift_length = get(survey, "drift_length", None) - if drift_length is not None and drift_length[i] > 0 and isinstance(line[name], xt.Drift): + if (drift_length is not None and drift_length[i] > 0 + and line[name].__class__.__name__ == 'Drift'): continue # skip drift spaces if self.ignore is not None: if np.any([re.match(pattern, name) is not None for pattern in self.ignore]): From c5810bfae7b12954538ebaed5dcc3274e818c647 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 5 Feb 2024 16:55:21 +0000 Subject: [PATCH 09/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xplt/line.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/xplt/line.py b/xplt/line.py index 4c86083..b010e44 100644 --- a/xplt/line.py +++ b/xplt/line.py @@ -19,6 +19,7 @@ from .util import defaults, get, defaults_for from .properties import Property, DataProperty + def iter_elements(line): """Iterate over elements in line @@ -334,8 +335,11 @@ def update(self, survey, line=None, autoscale=False): legend_entries = [] for i, (x, y, rt, name, arc) in enumerate(zip(X, Y, R, NAME, BEND)): drift_length = get(survey, "drift_length", None) - if (drift_length is not None and drift_length[i] > 0 - and line[name].__class__.__name__ == 'Drift'): + if ( + drift_length is not None + and drift_length[i] > 0 + and line[name].__class__.__name__ == "Drift" + ): continue # skip drift spaces if self.ignore is not None: if np.any([re.match(pattern, name) is not None for pattern in self.ignore]): From 73e41071b8368e114d79e9f7b1441c204a84f98c Mon Sep 17 00:00:00 2001 From: Philipp Niedermayer Date: Mon, 5 Feb 2024 18:33:06 +0100 Subject: [PATCH 10/11] Support line==None for plain MAD-X and some generalisations --- xplt/line.py | 64 +++++++++++++++++++++++----------------------------- 1 file changed, 28 insertions(+), 36 deletions(-) diff --git a/xplt/line.py b/xplt/line.py index b010e44..ea6015e 100644 --- a/xplt/line.py +++ b/xplt/line.py @@ -205,10 +205,10 @@ def __init__( projection (str): The projection to use: A pair of coordinates ('XZ', 'ZY' etc.) line (xtrack.Line): Line data with additional information about elements. Use this to have colored boxes of correct size etc. - boxes (None | bool | str | dict): Config option for showing colored boxes for elements. See below. + boxes (None | bool | str | iterable | dict): Config option for showing colored boxes for elements. See below. Detailed options can be "length" and all options suitable for a patch, such as "color", "alpha", etc. - labels (None | bool | str | dict): Config option for showing labels for elements. See below. + labels (None | bool | str | iterable | dict): Config option for showing labels for elements. See below. Detailed options can be "text" (e.g. "Dipole {name}" where name will be replaced with the element name) and all options suitable for an annotation, such as "color", "alpha", etc. @@ -221,6 +221,7 @@ def __init__( - None: Use good defaults. - A bool: En-/disable option for all elements (except drifts). - A str (regex): Filter by element name. + - A list, tuple or numpy array: Filter by any of the given element names - A dict: Detailed options to apply for each element in the form of `{"regex": {...}}`. For each matching element name, the options are used. @@ -247,6 +248,8 @@ def __init__( self.ignore = [ignore] if isinstance(ignore, str) else ignore self.element_width = element_width + if isinstance(self.boxes, (list, tuple, np.ndarray)): + self.boxes = "|".join(["^" + ll + "$" for ll in self.boxes]) if isinstance(self.labels, (list, tuple, np.ndarray)): self.labels = "|".join(["^" + ll + "$" for ll in self.labels]) @@ -275,7 +278,7 @@ def update(self, survey, line=None, autoscale=False): Args: survey (Any): Survey data. - line (xtrack.Line): Line data. + line (None | xtrack.Line): Line data. autoscale (bool): Whether or not to perform autoscaling on all axes Returns: @@ -335,11 +338,8 @@ def update(self, survey, line=None, autoscale=False): legend_entries = [] for i, (x, y, rt, name, arc) in enumerate(zip(X, Y, R, NAME, BEND)): drift_length = get(survey, "drift_length", None) - if ( - drift_length is not None - and drift_length[i] > 0 - and line[name].__class__.__name__ == "Drift" - ): + is_thick = line is not None and name in line.element_dict and line[name].isthick + if drift_length is not None and drift_length[i] > 0 and not is_thick: continue # skip drift spaces if self.ignore is not None: if np.any([re.match(pattern, name) is not None for pattern in self.ignore]): @@ -353,15 +353,11 @@ def update(self, survey, line=None, autoscale=False): element = line.element_dict.get(name) if line is not None else None order = get(element, "order", None) order = get(survey, "order", {i: order})[i] + if line is not None and name in line.element_dict: + order = ORDER_NAMED_ELEMENTS.get(line[name].__class__.__name__, order) + length = get(element, "length", None) length = get(survey, "length", {i: length})[i] - - # Patch order for thick elements - if name != "_end_point": - etype_name = line[name].__class__.__name__ - if etype_name in ORDER_NAMED_ELEMENTS: - order = ORDER_NAMED_ELEMENTS[etype_name] - if length is not None: length = length * scale @@ -395,25 +391,24 @@ def update(self, survey, line=None, autoscale=False): else: legend_entries.append(box_style.get("label")) - if length > 0 and arc: - - rho = length / arc - - if line[name].isthick: - # Find the center of the arc - x_mid = 0.5 * (x + X[i + 1]) - y_mid = 0.5 * (y + Y[i + 1]) - dr = np.array([X[i + 1] - x, Y[i + 1] - y, 0]) - dn = np.cross(dr, [0, 0, 1]) - dn /= np.linalg.norm(dn) - d = np.linalg.norm(dr) / 2 - sin_theta = np.abs(d / rho) - dh = d * sin_theta - p_center = np.array([x_mid, y_mid, 0]) - np.sign(arc) * dh * dn - x = p_center[0] - y = p_center[1] + # Handle thick elements + if is_thick and i + 1 < len(X): + # Find the center of the arc + x_mid = 0.5 * (x + X[i + 1]) + y_mid = 0.5 * (y + Y[i + 1]) + dr = np.array([X[i + 1] - x, Y[i + 1] - y, 0]) + dn = np.cross(dr, [0, 0, 1]) + dn /= np.linalg.norm(dn) + d = np.linalg.norm(dr) / 2 + sin_theta = np.abs(d * arc / length) + dh = d * sin_theta + p_center = np.array([x_mid, y_mid, 0]) - helicity * dh * dn + x = p_center[0] + y = p_center[1] + if length > 0 and arc: # bending elements as wedge + rho = length / arc box = mpl.patches.Wedge( **defaults_for( mpl.patches.Wedge, @@ -434,10 +429,6 @@ def update(self, survey, line=None, autoscale=False): ) else: - if line[name].isthick: - x = 0.5 * (x + X[i + 1]) - y = 0.5 * (y + Y[i + 1]) - # other elements as rect box = mpl.patches.Rectangle( **defaults_for( @@ -549,6 +540,7 @@ def _get_config(config, name, **default): return default +# Known class names from xtrack and their order ORDER_NAMED_ELEMENTS = { "Bend": 0, "Quadrupole": 1, From 892d31fe0a5cc651e3613fd6bd0822b418fabfcc Mon Sep 17 00:00:00 2001 From: Philipp Niedermayer Date: Mon, 5 Feb 2024 19:02:05 +0100 Subject: [PATCH 11/11] Add matplotlib style sheet --- docs/usage.md | 1 + xplt/__init__.py | 4 ++++ xplt/hooks.py | 12 ------------ xplt/xplt.mplstyle | 5 +++++ 4 files changed, 10 insertions(+), 12 deletions(-) create mode 100644 xplt/xplt.mplstyle diff --git a/docs/usage.md b/docs/usage.md index 998e311..2d380bb 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -22,6 +22,7 @@ Xsuite is not an explicit dependency, rather an API assumption on available attr ```python import xplt +xplt.apply_style() # use our matplotlib style sheet import numpy as np import pandas as pd diff --git a/xplt/__init__.py b/xplt/__init__.py index 9c90922..4ee398d 100644 --- a/xplt/__init__.py +++ b/xplt/__init__.py @@ -38,3 +38,7 @@ hooks.register_pint_options() except: pass + + +def apply_style(): + mpl.style.use("xplt.xplt") diff --git a/xplt/hooks.py b/xplt/hooks.py index 7b019c1..e25a664 100644 --- a/xplt/hooks.py +++ b/xplt/hooks.py @@ -40,18 +40,6 @@ def register_matplotlib_options(): cmap_r.name = cmap.name + "_r" mpl.cm.register_cmap(cmap=cmap_r) - # # set rcParams - # mpl.rcParams.update( - # { - # "figure.constrained_layout.use": True, - # "legend.fontsize": "x-small", - # "legend.title_fontsize": "small", - # "grid.color": "#DDD", - # "axes.prop_cycle": mpl.cycler(color=petroff_colors), - # # 'image.cmap': cmap_petroff_gradient, - # } - # ) - def register_pint_options(): """Register default options for pint""" diff --git a/xplt/xplt.mplstyle b/xplt/xplt.mplstyle new file mode 100644 index 0000000..8dc2fcc --- /dev/null +++ b/xplt/xplt.mplstyle @@ -0,0 +1,5 @@ +figure.constrained_layout.use: True +legend.fontsize: x-small +legend.title_fontsize: small +grid.color: "#DDD" +axes.prop_cycle: cycler('color', ['pet0', 'pet1', 'pet2', 'pet3', 'pet4', 'pet5', 'pet6', 'pet7', 'pet8', 'pet9'])