Suggestion: Replace _scan in batch.py with lax.scan #133

Open
mohamad-amin opened this issue Dec 11, 2021 · 0 comments
Labels
enhancement New feature or request

Comments

@mohamad-amin

I was reading through `batch.py` because I wanted to make some modifications, and I came across this function:

```python
def _scan(f: Callable[[_Carry, _Input], Tuple[_Carry, _Output]],
```

along with this note in its docstring:

```python
"""Implements an unrolled version of scan.
Based on jax.lax.scan and has a similar API.
TODO(schsam): We introduce this function because lax.scan currently has a
higher peak memory usage than the unrolled version. We will aim to swap this
out for lax.scan when issue #1273 and related have been resolved.
"""
```
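For context, "an unrolled version of scan" here means a plain Python loop that calls the step function once per element and stacks the per-step outputs, following the same `(carry, x) -> (carry, y)` contract as `jax.lax.scan`. The sketch below is hypothetical (`unrolled_scan` and `cumsum_step` are illustrative names, not the actual code in `batch.py`) and assumes `xs` is a pytree of arrays that all share the same leading axis.

```python
import jax.numpy as jnp
from jax import tree_util


def unrolled_scan(f, init, xs):
    """Hypothetical unrolled scan with a lax.scan-like (f, init, xs) signature.

    Iterates over the leading axis of every leaf in `xs` in plain Python, so
    the step function is traced once per element instead of once overall.
    """
    leaves, treedef = tree_util.tree_flatten(xs)
    length = leaves[0].shape[0]
    carry = init
    ys = []
    for i in range(length):
        # Slice step i out of every leaf and rebuild the input pytree.
        x = tree_util.tree_unflatten(treedef, [leaf[i] for leaf in leaves])
        carry, y = f(carry, x)
        ys.append(y)
    # Stack the per-step outputs along a new leading axis, as lax.scan does.
    stacked = tree_util.tree_map(lambda *ys_i: jnp.stack(ys_i), *ys)
    return carry, stacked


def cumsum_step(carry, x):
    # Running sum: the new carry is also emitted as this step's output.
    new_carry = carry + x
    return new_carry, new_carry


final, partial_sums = unrolled_scan(cumsum_step, 0.0, jnp.arange(1.0, 5.0))
print(float(final), partial_sums.tolist())
```

Here `final` is the total of `[1, 2, 3, 4]` and `partial_sums` holds the running sums at each step.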

The referenced issue has since been fixed:

jax-ml/jax#1273
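A minimal sketch of the suggested swap, assuming the step function already follows the `(carry, x) -> (carry, y)` contract: `jax.lax.scan` takes the same `(f, init, xs)` arguments and returns `(final_carry, stacked_ys)`. The names `cumsum_step` and `scanned_cumsum` are illustrative only and do not come from `batch.py`.

```python
import jax
import jax.numpy as jnp


def cumsum_step(carry, x):
    # Same (carry, x) -> (new_carry, y) contract an unrolled scan would use.
    new_carry = carry + x
    return new_carry, new_carry


@jax.jit
def scanned_cumsum(xs):
    # lax.scan traces `cumsum_step` once and compiles it into a single loop,
    # rather than emitting one copy of the body per element.
    # The init carry is typed to match xs so the carry dtype stays fixed.
    return jax.lax.scan(cumsum_step, jnp.zeros((), xs.dtype), xs)


final, partial_sums = scanned_cumsum(jnp.arange(1.0, 5.0))
print(float(final), partial_sums.tolist())
```

One design difference to keep in mind: `lax.scan` requires the carry to keep the same shape and dtype on every step (hence the explicitly typed initial carry), and every leaf of `xs` must share the same leading-axis length. Whether the swap actually lowers peak memory in `batch.py` would still need to be measured; `lax.scan`'s `unroll` argument offers a middle ground if partial unrolling turns out to matter.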

romanngg added the enhancement (New feature or request) label on Jan 21, 2022