Skip to content

Commit

Permalink
Support error checking in auto mode outside jit
Browse files Browse the repository at this point in the history
This PR support error checking in auto mode outside jit. It does not cover auto mode inside jit, because it’s technically impossible.

This PR also improves the existing test for scan.

PiperOrigin-RevId: 726939419
  • Loading branch information
ayaka14732 authored and Google-ML-Automation committed Feb 15, 2025
1 parent df135d2 commit 20f7e5f
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 7 deletions.
86 changes: 80 additions & 6 deletions jax/_src/error_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,21 @@

from __future__ import annotations

import contextlib
from functools import partial
import threading

import jax
from jax._src import core
from jax._src import source_info_util
from jax._src import traceback_util
import jax._src.mesh as mesh_lib
from jax._src.sharding_impls import (
NamedSharding,
PartitionSpec as P,
SingleDeviceSharding,
)
from jax.experimental.shard_map import shard_map
import jax.numpy as jnp


Expand All @@ -36,8 +45,8 @@ class JaxValueError(ValueError):
_NO_ERROR = jnp.iinfo(jnp.uint32).max
"""The default error code for no error.
We choose this value because when performing reductions, we can use `min` to
obtain the smallest error code.
This value is chosen because we can simply use `jnp.min()` to obtain the
smallest error code when performing reductions.
"""


Expand All @@ -53,11 +62,53 @@ def _initialize_error_code_ref() -> None:
_error_code_ref = core.mutable_array(error_code)


@contextlib.contextmanager
def error_checking_context(
mesh: mesh_lib.AbstractMesh | mesh_lib.Mesh | tuple[()] | None = None,
):
"""Redefine the error checking state based on the mesh.
This contaxt manager should be used to when starting a multi-device
computation, and whenever the mesh is changed.
When exiting the context, the error checking state will be reset to the
original state.
"""
global _error_code_ref
old_error_code_ref = _error_code_ref

# If mesh is not provided, get the abstract mesh from the context.
if mesh is None:
mesh = mesh_lib.get_abstract_mesh()

if mesh == (): # single-device case.
with core.eval_context():
error_code = jnp.uint32(_NO_ERROR)
_error_code_ref = core.mutable_array(error_code)

else: # multi-device case.
sharding = NamedSharding(mesh, P(*mesh.axis_names)) # type: ignore
# print([sharding.mesh._name_to_type[s] for s in sharding.spec])
with core.eval_context():
error_code = jnp.full(
mesh.axis_sizes, # type: ignore
jnp.uint32(_NO_ERROR),
device=sharding,
)
_error_code_ref = core.mutable_array(error_code)

try:
yield
finally:
_error_code_ref = old_error_code_ref

def set_error_if(pred: jax.Array, msg: str) -> None:
"""Set error if pred is true.
If the error is already set, the new error will be ignored. It will not
override the existing error.
In auto mode, this function does not work under jit.
"""
if _error_code_ref is None:
_initialize_error_code_ref()
Expand All @@ -69,7 +120,22 @@ def set_error_if(pred: jax.Array, msg: str) -> None:
new_error_code = len(_error_list)
_error_list.append((msg, traceback))

pred = pred.any()
if isinstance(_error_code_ref.sharding, SingleDeviceSharding): # pytype: disable=attribute-error
pred = pred.any()
else:
if _error_code_ref.sharding.mesh != pred.sharding.mesh: # pytype: disable=attribute-error
raise ValueError(
"The error code state and the predicate must be on the same mesh. "
"Please use `with error_checking_context()` to redefine the error "
"code state based on the mesh."
)
pred = shard_map(
partial(jnp.any, keepdims=True),
mesh=_error_code_ref.sharding.mesh, # pytype: disable=attribute-error
in_specs=pred.sharding.spec, # pytype: disable=attribute-error
out_specs=_error_code_ref.sharding.spec, # pytype: disable=attribute-error
)(pred)

error_code = _error_code_ref[...]
should_update = jnp.logical_and(pred, error_code == jnp.uint32(_NO_ERROR))
error_code = jnp.where(should_update, new_error_code, error_code)
Expand All @@ -78,11 +144,15 @@ def set_error_if(pred: jax.Array, msg: str) -> None:


def raise_if_error() -> None:
"""Raise error if an error is set."""
"""Raise error if an error is set.
This function should be called after the computation is finished. It should
be used outside jit.
"""
if _error_code_ref is None: # if not initialized, do nothing
return

error_code = _error_code_ref[...]
error_code = _error_code_ref[...].min() # perform per-device reduction
if error_code == jnp.uint32(_NO_ERROR):
return
try:
Expand All @@ -92,4 +162,8 @@ def raise_if_error() -> None:
filtered_traceback = traceback_util.filter_traceback(traceback)
raise exc.with_traceback(filtered_traceback)
finally:
_error_code_ref[...] = jnp.uint32(_NO_ERROR)
_error_code_ref[...] = jnp.full(
_error_code_ref.shape,
jnp.uint32(_NO_ERROR),
device=_error_code_ref.sharding, # pytype: disable=attribute-error
)
2 changes: 2 additions & 0 deletions tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1120,6 +1120,8 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "error_check_test",
srcs = ["error_check_test.py"],
enable_backends = ["tpu"],
enable_configs = ["tpu_v3_2x2"],
)

jax_multiplatform_test(
Expand Down
39 changes: 38 additions & 1 deletion tests/error_check_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@
from jax._src import config
from jax._src import error_check
from jax._src import test_util as jtu
from jax._src.sharding_impls import NamedSharding, PartitionSpec as P
import jax.numpy as jnp


JaxValueError = error_check.JaxValueError


config.parse_flags_with_absl()
jtu.request_cpu_devices(4)


@jtu.with_config(jax_check_tracer_leaks=True)
Expand Down Expand Up @@ -148,14 +150,18 @@ def body(x):
with self.assertRaisesRegex(JaxValueError, "x must be less than 10"):
error_check.raise_if_error()

def test_error_check_works_with_scan(self):
@parameterized.product(jit=[True, False])
def test_error_check_works_with_scan(self, jit):
def f(carry, x):
error_check.set_error_if(x >= 4, "x must be less than 4")
return carry + x, x + 1

def body(init, xs):
return jax.lax.scan(f, init=init, xs=xs)

if jit:
body = jax.jit(body)

init = jnp.int32(0)
xs = jnp.arange(5, dtype=jnp.int32)
_ = body(init, xs)
Expand All @@ -166,5 +172,36 @@ def body(init, xs):
_ = body(init, xs)
error_check.raise_if_error() # should not raise error

def test_error_checking_context(self):
if jax.device_count() != 4:
self.skipTest("test requires 4 devices")

# this test does not need to work under jit
def f(x):
error_check.set_error_if(x <= 0, "x must be greater than 0")
return x + 1

x = jnp.full((4, 4), -1, dtype=jnp.int32)
f(x)
with self.assertRaisesRegex(JaxValueError, "x must be greater than 0"):
error_check.raise_if_error()

mesh = jax.make_mesh((2, 2), ("x", "y"))
sharding = NamedSharding(mesh, P("x", "y"))

with jax.sharding.use_mesh(mesh):
with error_check.error_checking_context():
with mesh:
y = jnp.full((4, 4), -1, dtype=jnp.int32, device=sharding)
f(y)
with self.assertRaisesRegex(JaxValueError, "x must be greater than 0"):
error_check.raise_if_error() # should raise error

# The unsharded version of `f` should still work after exiting the error
# checking context.
f(x)
with self.assertRaisesRegex(JaxValueError, "x must be greater than 0"):
error_check.raise_if_error()

if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 20f7e5f

Please # to comment.