Feature request: Support vectorisation for Screen.reading histogram #351

Open
amylizzle opened this issue Feb 24, 2025 · 3 comments
Labels
enhancement New feature or request

Comments

@amylizzle

Screen readings are great outputs to train against, so being able to compute them in batches would be preferable!

@jank324
Member

jank324 commented Feb 24, 2025

Cheetah is fully vectorised. Unless there is a bug, it should therefore be no problem to get a batch of screen readings by, for example, supplying a vector of magnet settings with something like this:

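import cheetah
import torch

# Assumed setup: a default particle beam; substitute your own incoming beam here
incoming_beam = cheetah.ParticleBeam.from_parameters(num_particles=10_000)

segment = cheetah.Segment(
    [
        # Three corrector angles -> a batch of three beams
        cheetah.HorizontalCorrector(angle=torch.tensor([1e-5, 2e-5, 3e-5]), length=torch.tensor(0.15)),
        cheetah.Drift(length=torch.tensor(1.0)),
        cheetah.Screen(resolution=(200, 100), is_active=True, name="my_screen"),
    ],
)
outgoing_beam = segment.track(incoming_beam)

segment.my_screen.reading.shape   # Should be (3, 100, 200)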

@amylizzle
Author

Odd. I tried to get a batch of screen readings for a quadrupole with k1 set to a tensor of shape (5, 1) and got:

File /opt/conda/lib/python3.12/site-packages/cheetah/accelerator/screen.py:252, in Screen.reading(self)
    245 if self.method == "histogram":
    246     # Catch vectorisation, which is currently not supported by "histogram"
    247     if (
    248         len(read_beam.particles.shape) > 2
    249         or len(read_beam.particle_charges.shape) > 1
    250         or len(read_beam.energy.shape) > 0
    251     ):
--> 252         raise NotImplementedError(
    253             "The `'histogram'` method of `Screen` does not support "
    254             "vectorization. Use `'kde'` instead. If this is a feature you "
    255             "would like to see, please open an issue on GitHub."
    256         )
    258     image, _ = torch.histogramdd(
    259         torch.stack((read_beam.x, read_beam.y)).T,
    260         bins=self.pixel_bin_edges,
    261         weight=read_beam.particle_charges
    262         * read_beam.survival_probabilities,
    263     )
    264     image = torch.flipud(image.T)

NotImplementedError: The `'histogram'` method of `Screen` does not support vectorization. Use `'kde'` instead. If this is a feature you would like to see, please open an issue on GitHub.

@jank324
Member

jank324 commented Feb 24, 2025

Oh, you're right. I forgot that vectorisation is not supported when method="histogram". Unfortunately, this is because PyTorch's histogram function, torch.histogramdd, does not support vectorised (batched) inputs. If you set method="kde", vectorisation should work, though.
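To illustrate why a KDE image vectorises naturally, here is a minimal sketch in plain PyTorch. It is not Cheetah's implementation: the function name kde_screen_image, the single Gaussian bandwidth, and the pixel-centre grids are assumptions for illustration, and it skips Cheetah's image flipping. Every operation broadcasts over leading batch dimensions, so a batch of beams yields a batch of images in one call.

```python
import torch


def kde_screen_image(x, y, weights, x_centers, y_centers, bandwidth):
    """Hypothetical Gaussian-KDE screen image; leading (batch) dims broadcast.

    x, y, weights: (..., N) particle coordinates and weights
    x_centers: (W,) and y_centers: (H,) pixel-centre coordinates
    returns: (..., H, W), with row 0 corresponding to y_centers[0]
    """
    # Kernel value of every particle at every pixel column / row: (..., N, W), (..., N, H)
    kx = torch.exp(-0.5 * ((x.unsqueeze(-1) - x_centers) / bandwidth) ** 2)
    ky = torch.exp(-0.5 * ((y.unsqueeze(-1) - y_centers) / bandwidth) ** 2)
    # An axis-aligned 2D Gaussian factorises into x and y parts, so summing over
    # particles is one batched matmul: (..., H, N) @ (..., N, W) -> (..., H, W)
    image = (ky * weights.unsqueeze(-1)).transpose(-1, -2) @ kx
    # Normalise so each particle's kernel integrates to its weight
    return image / (2 * torch.pi * bandwidth**2)


torch.manual_seed(0)
B, N = 3, 1_000
x = torch.randn(B, N, dtype=torch.float64) * 1e-3
y = torch.randn(B, N, dtype=torch.float64) * 0.5e-3
w = torch.full((B, N), 1.0 / N, dtype=torch.float64)
x_centers = torch.linspace(-3e-3, 3e-3, 200, dtype=torch.float64)
y_centers = torch.linspace(-1.5e-3, 1.5e-3, 100, dtype=torch.float64)

images = kde_screen_image(x, y, w, x_centers, y_centers, bandwidth=1e-4)
print(images.shape)  # torch.Size([3, 100, 200])
```

Because the batch dimension simply rides along through broadcasting and the batched matmul, there is no per-sample loop anywhere; the cost is memory proportional to N × (H + W) per batch element.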

Of course, as soon as PyTorch supports this, we would like to add it to Cheetah as well (the same goes for a good hand-made implementation, if someone has one).
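As for a hand-made implementation, one possible approach (a sketch, not something that exists in Cheetah or PyTorch) is to compute each particle's bin index with torch.bucketize, offset it by the batch index, and accumulate all weights with a single index_add_ into a flat buffer. The function name batched_histogram2d and the (B, N) input layout are hypothetical. The final transpose and flip mirror the torch.flipud(image.T) step shown in the traceback above.

```python
import torch


def batched_histogram2d(x, y, weights, x_edges, y_edges):
    """Hypothetical batched weighted 2D histogram.

    x, y, weights: (B, N); x_edges: (W + 1,), y_edges: (H + 1,) increasing edges
    returns: (B, W, H), indexed like torch.histogramdd's output
    """
    B = x.shape[0]
    W, H = len(x_edges) - 1, len(y_edges) - 1
    # right=True gives bin j for edges[j] <= v < edges[j + 1]
    ix = torch.bucketize(x, x_edges, right=True) - 1
    iy = torch.bucketize(y, y_edges, right=True) - 1
    # Keep only particles inside the edges (rightmost edge included, as in histogramdd)
    inside = (x >= x_edges[0]) & (x <= x_edges[-1]) & (y >= y_edges[0]) & (y <= y_edges[-1])
    # Clamp so the rightmost edge lands in the last bin; outside particles get weight 0
    ix = ix.clamp(0, W - 1)
    iy = iy.clamp(0, H - 1)
    # Flat index into a (B, W, H) buffer, offset per batch element
    batch_idx = torch.arange(B, device=x.device).unsqueeze(-1).expand_as(ix)
    flat = (batch_idx * W + ix) * H + iy
    counts = torch.zeros(B * W * H, dtype=weights.dtype, device=x.device)
    counts.index_add_(0, flat.reshape(-1), (weights * inside).reshape(-1))
    return counts.reshape(B, W, H)


torch.manual_seed(0)
B, N = 3, 5_000
x = torch.randn(B, N, dtype=torch.float64) * 1e-3
y = torch.randn(B, N, dtype=torch.float64) * 0.5e-3
w = torch.rand(B, N, dtype=torch.float64)
x_edges = torch.linspace(-2e-3, 2e-3, 201, dtype=torch.float64)
y_edges = torch.linspace(-1e-3, 1e-3, 101, dtype=torch.float64)

hist = batched_histogram2d(x, y, w, x_edges, y_edges)  # (3, 200, 100)
images = hist.transpose(-1, -2).flip(-2)  # batched equivalent of flipud(image.T)
print(images.shape)  # torch.Size([3, 100, 200])
```

Like torch.histogramdd, the bins are half-open except the last, and particles outside the edges are dropped. Note that the counts are piecewise constant in the particle positions, so gradients only reach the weights; a KDE image, by contrast, is smooth in x and y.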

@jank324 jank324 added the enhancement New feature or request label Feb 24, 2025