Skip to content

Commit

Permalink
[JAX] Allow pallas to accept scalar shape semaphores.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 727091407
  • Loading branch information
Marcello Maggioni authored and Google-ML-Automation committed Feb 15, 2025
1 parent d3850e7 commit a8265ec
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 1 deletion.
5 changes: 5 additions & 0 deletions jax/_src/pallas/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,15 @@ py_library(
deps = [
"//jax",
"//jax:ad_util",
"//jax:api_util",
"//jax:config",
"//jax:core",
"//jax:dtypes",
"//jax:effects",
"//jax:mlir",
"//jax:partial_eval",
"//jax:pretty_printer",
"//jax:source_info_util",
"//jax:tree_util",
"//jax:util",
"//jax/_src/lib",
Expand Down
2 changes: 2 additions & 0 deletions jax/_src/pallas/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1191,3 +1191,5 @@ def lower_as_mlir(
stablehlo = lowered.compiler_ir(dialect="stablehlo")

return stablehlo # type: ignore[return-value]

# Registry mapping an out_shape type to a function converting it to an aval.
# Backend modules register entries here (e.g. mosaic/core.py registers
# SemaphoreType) so pallas_call can accept backend-specific out_shape objects.
_out_shape_to_aval_mapping = {}
12 changes: 12 additions & 0 deletions jax/_src/pallas/mosaic/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
AbstractMemoryRef = pallas_core.AbstractMemoryRef
no_block_spec = pallas_core.no_block_spec
_convert_block_spec_to_block_mapping = pallas_core._convert_block_spec_to_block_mapping
_out_shape_to_aval_mapping = pallas_core._out_shape_to_aval_mapping
split_list = util.split_list

_ENABLE_RUNTIME_ASSERT = config.bool_state(
Expand Down Expand Up @@ -278,3 +279,14 @@ def _tensorcore_mesh_discharge_rule(
pallas_core._core_map_mesh_rules[TensorCoreMesh] = (
_tensorcore_mesh_discharge_rule
)


def _convert_semaphore_type_to_aval(
out_shape: SemaphoreType,
) -> jax_core.AbstractValue:
return out_shape.get_array_aval()


pallas_core._out_shape_to_aval_mapping[SemaphoreType] = (
_convert_semaphore_type_to_aval
)
6 changes: 5 additions & 1 deletion jax/_src/pallas/pallas_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.pallas import core as pallas_core
from jax._src.pallas import primitives
from jax._src.pallas import helpers as pallas_helpers
from jax._src.pallas import hlo_interpreter
from jax._src.pallas import primitives
from jax._src.state import discharge as state_discharge
from jax._src.state import types as state_types
from jax._src.util import (
Expand Down Expand Up @@ -1337,6 +1337,10 @@ def _convert_out_shape_to_aval(out_shape: Any) -> jax_core.AbstractValue:
case pallas_core.MemoryRef():
return out_shape.get_array_aval()
case _:
if type(out_shape) in pallas_core._out_shape_to_aval_mapping:
return pallas_core._out_shape_to_aval_mapping[type(out_shape)](
out_shape
)
if not (hasattr(out_shape, "shape") and hasattr(out_shape, "dtype")):
raise ValueError(f"Invalid out_shape type: {type(out_shape)}")
return jax_core.ShapedArray(shape=out_shape.shape, dtype=out_shape.dtype)
Expand Down
49 changes: 49 additions & 0 deletions tests/pallas/tpu_pallas_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1188,6 +1188,55 @@ def body(sem):
)(x)
np.testing.assert_array_equal(y, x)

  def test_output_dma_semaphore_ref(self):
    """Tests returning an in-flight DMA semaphore as a pallas_call output.

    The first kernel starts an HBM->HBM copy and returns the DMA semaphore
    (declared via out_shape=pltpu.SemaphoreType.DMA); a second kernel then
    receives that semaphore as an input and waits on it.
    """
    if self.INTERPRET:
      self.skipTest('TODO(sharadmv, justinfu): Add interpret support for DMA.')

    def kernel(x_hbm_ref, y_hbm_ref, sem_out):
      # Start the copy but do NOT wait: the pending semaphore is the output.
      pltpu.make_async_copy(
          x_hbm_ref.at[pl.ds(8), :], y_hbm_ref.at[:, pl.ds(128)], sem_out
      ).start()

    def kernel2(x_hbm_ref, y_hbm_ref, sem_in, y_hbm_out):
      # y_hbm_out only exists to alias y_hbm_ref (see input_output_aliases).
      del y_hbm_out
      # Wait on the semaphore produced by the first kernel; same copy
      # descriptor so the wait matches the started DMA.
      pltpu.make_async_copy(
          x_hbm_ref.at[pl.ds(8), :], y_hbm_ref.at[:, pl.ds(128)], sem_in
      ).wait()

    x = jnp.arange(8 * 128.0).reshape((8, 128))

    @jax.jit
    def body(x):
      # First call: outputs both the copy destination and the DMA semaphore
      # (SEMAPHORE memory space, scalar-shaped SemaphoreType out_shape).
      y, sem_out = self.pallas_call(
          kernel,
          in_specs=[
              pl.BlockSpec(memory_space=pl.ANY),
          ],
          out_specs=[
              pl.BlockSpec(memory_space=pl.ANY),
              pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
          ],
          out_shape=[
              jax.ShapeDtypeStruct((8, 128), jnp.float32),
              pltpu.SemaphoreType.DMA,
          ],
      )(x)

      # Second call: consumes the semaphore and waits for the copy to land.
      # input_output_aliases={1: 0} aliases y into the output buffer.
      y = self.pallas_call(
          kernel2,
          in_specs=[
              pl.BlockSpec(memory_space=pl.ANY),
              pl.BlockSpec(memory_space=pl.ANY),
              pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
          ],
          out_specs=pl.BlockSpec(memory_space=pl.ANY),
          out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
          input_output_aliases={1: 0},
      )(x, y, sem_out)
      return y

    # After the completed copy, the output equals the input.
    np.testing.assert_array_equal(body(x), x)

def test_hbm_hbm_grid_dma(self):
# When using the grid, we have to emit Mosaic window_params. Test that they
# work correctly with ANY memory space operands.
Expand Down

0 comments on commit a8265ec

Please sign in to comment.