I was reading through this file because I wanted to make some modifications, and I came across this function (`neural-tangents/neural_tangents/utils/batch.py`, line 126 at commit 3deb197), along with this comment:

```python
"""Implements an unrolled version of scan.
Based on jax.lax.scan and has a similar API.
TODO(schsam): We introduce this function because lax.scan currently has a
higher peak memory usage than the unrolled version. We will aim to swap this
out for lax.scan when issue #1273 and related have been resolved.
"""
```

The issue that the TODO refers to has since been fixed: jax-ml/jax#1273. So it may now be possible to swap this function out for `lax.scan`, as the TODO intends.
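For context, the docstring above describes an "unrolled" scan with an API similar to `jax.lax.scan`. The sketch below is only an illustration, not the actual `neural_tangents` code (which isn't reproduced here). It is written in plain Python and assumes `lax.scan`'s calling convention, where the body is `f(carry, x) -> (carry, y)`. The names `unrolled_scan` and `running_sum` are hypothetical.

```python
from typing import Callable, List, Sequence, Tuple, TypeVar

C = TypeVar("C")  # carry type
X = TypeVar("X")  # per-step input type
Y = TypeVar("Y")  # per-step output type


def unrolled_scan(
    f: Callable[[C, X], Tuple[C, Y]],
    init: C,
    xs: Sequence[X],
) -> Tuple[C, List[Y]]:
    """Hypothetical helper: applies f step by step, lax.scan-style.

    Each step is a separate Python call to f. If this loop were traced by
    jax.jit, the compiled program would contain one copy of f per element
    of xs, which is what "unrolled" means here.
    """
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)  # thread the carry through each step
        ys.append(y)            # collect the per-step outputs
    return carry, ys


# Example body: a running sum, where the carry is the total so far.
def running_sum(total, x):
    total = total + x
    return total, total


final, partials = unrolled_scan(running_sum, 0, [1, 2, 3, 4])
print(final, partials)  # 10 [1, 3, 6, 10]
```

In JAX, the same body function could be passed directly to `jax.lax.scan(f, init, xs)`. That call traces `f` once and lowers it to a loop, returning the final carry and the stacked outputs. Its `unroll` parameter controls how many iterations are unrolled into each loop step, which offers a middle ground between the two approaches.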