From 9b979d9fb908301f3098efa26494f807cf51930c Mon Sep 17 00:00:00 2001
From: Jacob Burnim
Date: Tue, 11 Feb 2025 15:00:01 -0800
Subject: [PATCH] [Mosaic] Several fixes/improvements for the new TPU
 interpret mode.

- Check bounds for reads and writes to shared memory.
- Pad kernel arguments when necessary.
- Fix support for input-output aliasing.
- Fix handling of vmap'ed dimensions.
- Support un-masked `pl.load` and masked or un-masked `pl.swap`.
- Switch to using single integer device IDs instead of tuples.
- Add better error messages for unsupported primitives: `for_p`,
  `atomic_rmw_p`, and `atomic_cas_p`.

PiperOrigin-RevId: 725784360
---
 jax/_src/pallas/mosaic/interpret.py | 351 ++++++++++++++++++++--------
 1 file changed, 248 insertions(+), 103 deletions(-)

diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py
index 33fe0ae60c88..bf1696b33c50 100644
--- a/jax/_src/pallas/mosaic/interpret.py
+++ b/jax/_src/pallas/mosaic/interpret.py
@@ -15,7 +15,7 @@
 import collections
 from collections.abc import Iterable, Sequence
 import dataclasses
-from functools import reduce
+import functools
 import math
 import threading
 from typing import Any
@@ -24,6 +24,7 @@
 from jax import lax
 from jax._src import callback
 from jax._src import core as jax_core
+from jax._src.lax.control_flow import for_loop
 from jax._src import linear_util as lu
 from jax._src.pallas.mosaic import primitives as mosaic_primitives
 from jax._src.pallas.mosaic import core as mosaic_core
@@ -72,13 +73,11 @@ def __init__(self):
     self.counts = collections.defaultdict(int)
 
   def signal(self, inc, device_id):
-    device_id = tuple(int(x) for x in device_id)
     with self.cv:
       self.counts[device_id] += inc
       self.cv.notify_all()
 
   def wait(self, value, device_id):
-    device_id = tuple(int(x) for x in device_id)
     with self.cv:
       while self.counts[device_id] < value:
         self.cv.wait()
@@ -120,7 +119,7 @@ def _clear_shared_memory():
     _shared_memory = None
 
 def _allocate_buffer(device_id, memory_space, val):
-  device_id = tuple(map(int, device_id))
+  device_id = int(device_id)
   memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)]
   val = np.array(val)
@@ -134,7 +133,7 @@
   return np.int16(buffer_id)
 
 def _deallocate_buffer(device_id, memory_space, buffer_id):
-  device_id = tuple(map(int, device_id))
+  device_id = int(device_id)
   memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)]
   buffer_id = int(buffer_id)
@@ -144,7 +143,7 @@
     shared_memory.mem.pop((memory_space, buffer_id, device_id), None)
 
 def _allocate_semaphores(device_id, shape):
-  device_id = tuple(map(int, device_id))
+  device_id = int(device_id)
   shape = tuple(map(int, shape))
   num_semaphores = math.prod(shape)
@@ -176,7 +175,7 @@
       TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.VMEM])
 
 def get_barrier_semaphore(device_id, collective_id):
-  device_id = tuple(map(int, device_id))
+  del device_id
   collective_id = int(collective_id)
 
   # TODO(jburnim): Check/fix so that IDs for barrier semaphores do not conflict
@@ -195,7 +194,9 @@ def _transform_slice_or_index(slice_or_idx):
     return slice_or_idx
   else:
     start, size, stride = (
-        slice_or_idx.start, slice_or_idx.size, slice_or_idx.stride)
+        int(slice_or_idx.start),
+        int(slice_or_idx.size),
+        int(slice_or_idx.stride))
     return slice(start, start + size * stride, stride)
 
 def _compose_slice_or_index(slice_or_idx1, slice_or_idx2):
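[Review note] The hunks above replace tuple-of-coordinates device IDs with a
single integer. The flattening is row-major over the mesh axes, as implemented
by `_device_coords_to_logical_id` later in this patch. A minimal sketch of
that computation (names here are illustrative, not part of the patch):

    import math

    def device_coords_to_logical_id(device_coords, axis_sizes):
      # Row-major flattening: the last mesh axis varies fastest.
      sizes = list(axis_sizes.values())
      ret = 0
      for i, c in enumerate(device_coords):
        ret += c * math.prod(sizes[i + 1:])
      return ret

    # Coordinate (1, 2) in a mesh of {'x': 2, 'y': 4} -> 1 * 4 + 2 = 6.
    assert device_coords_to_logical_id((1, 2), {'x': 2, 'y': 4}) == 6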
@@ -234,7 +235,7 @@ def _to_range(transforms) -> tuple[slice | int, ...]:
   return ret
 
 def get(device_id, memory_space, buffer_id, transforms):
-  device_id = tuple(int(x) for x in device_id)
+  device_id = int(device_id)
   memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)]
   buffer_id = int(buffer_id)
   try:
@@ -244,12 +245,21 @@ def get(device_id, memory_space, buffer_id, transforms):
 
   shared_memory = _get_shared_memory()
   with shared_memory.lock:
-    return shared_memory.mem[(memory_space, buffer_id, device_id)][
-        _to_range(transforms)
-    ].copy()
+    read_range = _to_range(transforms)
+    buffer = shared_memory.mem[(memory_space, buffer_id, device_id)]
+    ret = buffer[read_range].copy()
+    if transforms:
+      # TODO(jburnim): Instead of using NDIndexer, do the computation ourselves
+      # with buffer.shape and read_range?
+      expected_shape = transforms[-1].get_indexer_shape()
+      if expected_shape != ret.shape[:len(expected_shape)]:
+        raise ValueError(
+            f'Out-of-bounds read of ({device_id} {memory_space} {buffer_id}): '
+            f'reading [{read_range}] but buffer has shape {buffer.shape}.')
+    return ret
 
 def store(device_id, memory_space, buffer_id, transforms, val):
-  device_id = tuple(int(x) for x in device_id)
+  device_id = int(device_id)
   memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)]
   buffer_id = int(buffer_id)
   try:
@@ -260,15 +270,18 @@
 
   shared_memory = _get_shared_memory()
   with shared_memory.lock:
-    if transforms:
-      shared_memory.mem[(memory_space, buffer_id, device_id)][
-          _to_range(transforms)
-      ] = val
-    else:
-      shared_memory.mem[(memory_space, buffer_id, device_id)] = val
-
-def swap(device_id, memory_space, buffer_id, transforms, val):
-  device_id = tuple(int(x) for x in device_id)
+    buff = shared_memory.mem[(memory_space, buffer_id, device_id)]
+    write_range = _to_range(transforms)
+    # TODO(jburnim): Better error message if this raises?
+    in_bounds_shape = buff[write_range].shape
+    if in_bounds_shape != val.shape:
+      raise ValueError(
+          f'Out-of-bounds write of ({device_id} {memory_space} {buffer_id}): '
+          f'writing [{write_range}] but buffer has shape {buff.shape}.')
+    buff[write_range] = val
+
+def swap(device_id, memory_space, buffer_id, transforms, val, mask):
+  device_id = int(device_id)
   memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)]
   buffer_id = int(buffer_id)
   try:
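[Review note] The new bounds checks exist because NumPy basic slicing clips
out-of-range slices rather than raising, so an out-of-bounds access would
otherwise silently read or write a smaller region. A small self-contained
demonstration of the failure mode being guarded against:

    import numpy as np

    buf = np.zeros((8, 128))
    window = buf[4:12, :]          # Requested 8 rows, silently got only 4.
    assert window.shape == (4, 128)

This is why `get` compares the result's shape against the indexer's expected
shape, and `store` compares the in-bounds shape against `val.shape`.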
@@ -276,17 +289,42 @@
     transforms = jax.tree.map(int, transforms)
   except:
     raise ValueError('Advanced indexers are not supported on TPU')
   val = np.array(val)
+  mask = np.array(mask) if mask is not None else None
+  if mask is not None:
+    assert mask.shape == val.shape
 
   shared_memory = _get_shared_memory()
   with shared_memory.lock:
-    result = shared_memory.mem[(memory_space, buffer_id, device_id)][
-        _to_range(transforms)
-    ].copy()
-    shared_memory.mem[(memory_space, buffer_id, device_id)][
-        _to_range(transforms)
-    ] = val
-
-  return np.array(result)
+    buff = shared_memory.mem[(memory_space, buffer_id, device_id)]
+    read_write_range = _to_range(transforms)
+    # TODO(jburnim): Better error message if this raises?
+    raw_result = buff[read_write_range]
+    in_bounds_shape = raw_result.shape
+    if mask is None:
+      if in_bounds_shape != val.shape:
+        raise ValueError(
+            f'Out-of-bounds swap of ({device_id} {memory_space} {buffer_id}): '
+            f'swapping [{read_write_range}] but buffer has shape {buff.shape}.')
+      buff[read_write_range] = val
+      return raw_result.copy()
+
+    in_bounds_mask = np.full(mask.shape, True)
+    for i, size in enumerate(in_bounds_shape):
+      in_bounds_mask[(slice(None),) * i + (slice(size, None),)] = False
+    if (~in_bounds_mask & mask).any():
+      # TODO(jburnim): Include indices of out-of-bounds locations where mask
+      # is True.
+      raise ValueError(
+          f'Out-of-bounds masked swap of ({device_id} {memory_space} '
+          f'{buffer_id}): swapping [{read_write_range}] but buffer has '
+          f'shape {buff.shape}.')
+
+    in_bounds_idx = tuple(slice(i) for i in in_bounds_shape)
+    result = val.copy()
+    result[in_bounds_idx] = np.where(
+        mask[in_bounds_idx], raw_result, val[in_bounds_idx])
+    buff[read_write_range] = np.where(
+        mask[in_bounds_idx], val[in_bounds_idx], raw_result)
+    return result
 
 def execute_dma(src, dst, send_sem, recv_sem):
   # NOTE: `src` is a list of arguments for `get` (device_id, memory_space,
@@ -310,7 +348,7 @@ def execute_dma(src, dst, send_sem, recv_sem):
   recv_sem.signal(data_size, device_id=dst[0])
 
 def print_memory(device_id):
-  device_id = tuple(map(int, device_id))
-  if all(d == 0 for d in device_id):
+  device_id = int(device_id)
+  if device_id == 0:
     shared_memory = _get_shared_memory()
     with shared_memory.lock:
@@ -321,7 +359,7 @@ def dma_start(device_id, src_memory_space, src_id, src_transforms,
               dst_sem, src_sem, dst_device_id):
-  device_id = tuple(int(x) for x in device_id)
+  device_id = int(device_id)
   src_memory_space, src_id = int(src_memory_space), int(src_id)
   src_transforms = jax.tree.map(int, src_transforms)
   dst_memory_space, dst_id = int(dst_memory_space), int(dst_id)
@@ -330,7 +368,7 @@ def dma_start(device_id, src_memory_space, src_id, src_transforms,
   if src_sem is not None:
     src_sem = int(src_sem)
   if dst_device_id is not None:
-    dst_device_id = tuple(int(x) for x in dst_device_id)
+    dst_device_id = int(dst_device_id)
   else:
     dst_device_id = device_id
@@ -349,7 +387,7 @@ def dma_start(device_id, src_memory_space, src_id, src_transforms,
       dst_sem)
 
 def dma_wait(device_id, sem, size):
-  device_id = tuple(int(x) for x in device_id)
+  device_id = int(device_id)
   sem = int(sem)
   size = int(size)
@@ -359,13 +397,16 @@ def dma_wait(device_id, sem, size):
     sem.wait(size, device_id)
 
 def semaphore_signal(device_id, sem, inc, target_device_id, target_core_index):
-  device_id = tuple(map(int, device_id))
+  device_id = int(device_id)
   sem = int(sem)
   inc = int(inc)
-  target_device_id = tuple(map(int, target_device_id))
+  if target_device_id is None:
+    target_device_id = device_id
+  else:
+    target_device_id = int(target_device_id)
 
   if target_core_index is not None:
-    raise NotImplementedError()
+    raise NotImplementedError('semaphore_signal with target_core_index')
 
   shared_memory = _get_shared_memory()
   with shared_memory.lock:
@@ -373,7 +414,7 @@ def semaphore_signal(device_id, sem, inc, target_device_id, target_core_index):
     sem.signal(inc, target_device_id)
 
 def semaphore_wait(device_id, sem, value):
-  device_id = tuple(map(int, device_id))
+  device_id = int(device_id)
   sem = int(sem)
   value = int(value)
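[Review note] The masked `swap` above is a lane-wise read-modify-write:
masked-on lanes exchange buffer and value, masked-off lanes leave the buffer
untouched and return the input value. A runnable sketch of those semantics
(illustrative, 1-D, all names hypothetical):

    import numpy as np

    buf = np.arange(8)
    val = np.full(8, 100)
    mask = np.array([True, False] * 4)

    old = buf.copy()
    returned = np.where(mask, old, val)   # Old data where mask, else val.
    buf[:] = np.where(mask, val, old)     # New data where mask, else old.

    assert list(buf) == [100, 1, 100, 3, 100, 5, 100, 7]
    assert list(returned) == [0, 100, 2, 100, 4, 100, 6, 100]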
@@ -390,6 +431,26 @@ def _compute_transformed_shape_and_dtype(shape, dtype, transforms):
     dtype = transform.transform_dtype(dtype)
   return shape, dtype
 
+def _device_coords_to_logical_id(device_coords, axis_sizes):
+  if not isinstance(device_coords, tuple):
+    device_coords = (device_coords,)
+  assert len(device_coords) == len(axis_sizes)
+  sizes = list(axis_sizes.values())
+  ret = 0
+  for i in range(len(device_coords)):
+    ret += device_coords[i] * math.prod(sizes[i+1:])
+  return ret
+
+def _device_id_to_logical(device_id, device_id_type, axis_sizes):
+  if device_id is None:
+    return None
+  if device_id_type == mosaic_primitives.DeviceIdType.MESH:
+    return _device_coords_to_logical_id(device_id, axis_sizes)
+  elif device_id_type == mosaic_primitives.DeviceIdType.LOGICAL:
+    return device_id
+  else:
+    raise ValueError(f'Unsupported device ID type: {device_id_type}')
+
 @lu.cache
 def _to_jaxpr(flat_fun, in_avals):
   new_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
@@ -414,10 +475,11 @@ def write(var, value):
 
   jax.util.safe_map(write, jaxpr.constvars + jaxpr.invars, args)
 
-  # Get the mesh coordinates.
-  device_coords = tuple(
-      lax.axis_index(s) for s in jax_core.get_axis_env().axis_sizes)
-  # TODO(jburnim): Convert to a single integer device ID.
+  # Get the device ID.
+  axis_sizes = jax_core.get_axis_env().axis_sizes
+  device_id = _device_coords_to_logical_id(
+      tuple(lax.axis_index(s) for s in axis_sizes.keys()),
+      axis_sizes)
 
   # TODO(jburnim): Pass the device ID around, instead of re-fetching/computing
   # it for each sub-jaxpr.
@@ -431,10 +493,32 @@ def write(var, value):
     invals = jax.util.safe_map(read, eqn.invars)
 
     if prim is primitives.load_p:
-      raise NotImplementedError()
+      (ref, transforms, mask, _) = jax.tree.unflatten(
+          eqn.params['args_tree'], invals)
+      if mask is not None:
+        raise NotImplementedError('masked load_p')
+      out = callback.io_callback(
+          get,
+          eqn.outvars[0].aval,
+          device_id,
+          TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space],
+          ref,
+          transforms,
+          ordered=True)
 
     elif prim is primitives.swap_p:
-      raise NotImplementedError()
+      (ref, transforms, val, mask) = jax.tree.unflatten(
+          eqn.params['args_tree'], invals)
+      out = callback.io_callback(
+          swap,
+          eqn.outvars[0].aval,
+          device_id,
+          TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space],
+          ref,
+          transforms,
+          val,
+          mask,
+          ordered=True)
 
     elif prim is lax.cond_p:
       def _make_branch(jaxpr):
@@ -470,15 +554,18 @@ def _scan_body(c, a):
               compiler_params=compiler_params),
           init_vals)
 
+    elif prim is for_loop.for_p:
+      raise NotImplementedError('for_p')
+
     elif prim is pjit.pjit_p:
-      pjit_jaxpr = eqn.params['jaxpr']
-      def f(*args):
-        return _interpret_jaxpr(pjit_jaxpr.jaxpr, *pjit_jaxpr.consts, *args,
+      def f(*args, jaxpr):
+        return _interpret_jaxpr(jaxpr.jaxpr, *jaxpr.consts, *args,
                                 compiler_params=compiler_params)
       in_avals = tuple(jax_core.shaped_abstractify(i) for i in invals)
-      new_jaxpr = _to_jaxpr(lu.wrap_init(f,
-                                         debug_info=pjit_jaxpr.jaxpr.debug_info),
-                            in_avals)
+      new_jaxpr = _to_jaxpr(
+          lu.wrap_init(functools.partial(f, jaxpr=eqn.params['jaxpr']),
+                       debug_info=eqn.params['jaxpr'].jaxpr.debug_info),
+          in_avals)
       out = pjit.pjit_p.bind(*invals, **(eqn.params | {'jaxpr': new_jaxpr}))
 
     elif prim is primitives.run_scoped_p:
@@ -490,14 +577,14 @@ def f(*args):
           allocs.append(callback.io_callback(
               _allocate_semaphores,
               jax.ShapeDtypeStruct(v.aval.shape, jnp.int16),
-              device_coords,
+              device_id,
               v.aval.shape,
               ordered=True))
         else:
           allocs.append(callback.io_callback(
               _allocate_buffer,
               jax.ShapeDtypeStruct((), jnp.int16),
-              device_coords,
+              device_id,
               TPU_MEMORY_SPACE_IDXS[v.aval.memory_space],
               primitives.uninitialized_value(v.aval.shape, v.aval.dtype),
               ordered=True))
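[Review note] In the pjit case above, the jaxpr is now passed to `f` via
`functools.partial` instead of being captured by a closure. One reason to
prefer this pattern (the patch itself does not state its motivation) is
Python's late binding of closed-over variables; a minimal illustration:

    import functools

    # All three closures see the final value of x.
    fns = [lambda: x for x in range(3)]
    assert [f() for f in fns] == [2, 2, 2]

    # Binding the value as an argument freezes it per-function.
    fns = [functools.partial(lambda v: v, x) for x in range(3)]
    assert [f() for f in fns] == [0, 1, 2]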
@@ -510,7 +597,7 @@ def f(*args):
           callback.io_callback(
               _deallocate_buffer,
               None,
-              device_coords,
+              device_id,
               TPU_MEMORY_SPACE_IDXS[v.aval.memory_space],
               a,
               ordered=True)
@@ -519,7 +606,7 @@ def f(*args):
           # callback.io_callback(
           #     _deallocate_semaphores,
           #     None,
-          #     device_coords,
+          #     device_id,
           #     a,
           #     ordered=True)
           pass
@@ -528,7 +615,7 @@ def f(*args):
       out = callback.io_callback(
           get,
           eqn.outvars[0].aval,
-          device_coords,
+          device_id,
           TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space],
           invals[0],
           jax.tree.unflatten(eqn.params['tree'], invals[1:]),
@@ -538,11 +625,12 @@ def f(*args):
       out = callback.io_callback(
           swap,
           eqn.outvars[0].aval,
-          device_coords,
+          device_id,
           TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space],
           invals[0],
           jax.tree.unflatten(eqn.params['tree'], invals[2:]),
           invals[1],
+          None,
           ordered=True)
 
     elif prim is mosaic_primitives.dma_start_p:
       (src, src_transforms,
@@ -550,20 +638,22 @@ def f(*args):
        dst, dst_transforms,
        dst_sem, dst_sem_transforms,
        src_sem, src_sem_transforms,
-       device_id) = jax.tree.unflatten(eqn.params['tree'], invals)
+       target_device_id) = jax.tree.unflatten(eqn.params['tree'], invals)
+      target_device_id = _device_id_to_logical(
+          target_device_id, eqn.params['device_id_type'], axis_sizes)
       (orig_src_ref, _, orig_dst_ref, *_
        ) = jax.tree.unflatten(eqn.params['tree'], eqn.invars)
       callback.io_callback(
           dma_start, (),
-          device_coords,
+          device_id,
           TPU_MEMORY_SPACE_IDXS[orig_src_ref.aval.memory_space],
           src, src_transforms,
           TPU_MEMORY_SPACE_IDXS[orig_dst_ref.aval.memory_space],
           dst, dst_transforms,
           state_discharge.transform_array(dst_sem, dst_sem_transforms),
           state_discharge.transform_array(src_sem, src_sem_transforms),
-          device_id,
+          target_device_id,
           ordered=True)
       out = []
@@ -572,13 +662,13 @@ def f(*args):
        dst, dst_transforms,
        dst_sem, dst_sem_transforms,
        src_sem, src_sem_transforms,
-       device_id) = jax.tree.unflatten(eqn.params['tree'], invals)
+       target_device_id) = jax.tree.unflatten(eqn.params['tree'], invals)
       read_shape, read_dtype = _compute_transformed_shape_and_dtype(
           eqn.invars[0].aval.shape, eqn.invars[0].aval.dtype, src_transforms)
       callback.io_callback(
           dma_wait, (),
-          device_coords,
+          device_id,
           state_discharge.transform_array(dst_sem, dst_sem_transforms),
           math.prod(read_shape) * read_dtype.itemsize,
           ordered=True)
@@ -588,20 +678,22 @@ def f(*args):
       out = callback.io_callback(
           get_barrier_semaphore,
           jax.ShapeDtypeStruct((), jnp.int16),
-          device_coords,
+          device_id,
           compiler_params['mosaic']['collective_id'],
           ordered=True)
 
     elif prim is mosaic_primitives.semaphore_signal_p:
-      sem, sem_transforms, inc, device_id, core_index = (
+      sem, sem_transforms, inc, target_device_id, core_index = (
           jax.tree.unflatten(eqn.params['args_tree'], invals))
+      target_device_id = _device_id_to_logical(
+          target_device_id, eqn.params['device_id_type'], axis_sizes)
       callback.io_callback(
           semaphore_signal, (),
-          device_coords,
+          device_id,
           state_discharge.transform_array(sem, sem_transforms),
           inc,
-          device_id,
+          target_device_id,
           core_index,
           ordered=True)
       out = []
@@ -612,14 +704,21 @@ def f(*args):
       callback.io_callback(
           semaphore_wait, (),
-          device_coords,
+          device_id,
           state_discharge.transform_array(sem, sem_transforms),
           value,
           ordered=True)
       out = []
 
+    elif prim is primitives.atomic_rmw_p:
+      raise NotImplementedError('atomic_rmw_p')
+
+    elif prim is primitives.atomic_cas_p:
+      raise NotImplementedError('atomic_cas_p')
+
     else:
-      out = prim.bind(*invals, **eqn.params)
+      subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
+      out = prim.bind(*subfuns, *invals, **bind_params)
 
     out = out if prim.multiple_results else [out]
     jax.util.safe_map(write, eqn.outvars, out)
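[Review note] The new `else` branch forwards any unhandled primitive the same
way `jax.core.eval_jaxpr` does: `get_bind_params` recovers the positional
sub-functions and keyword parameters that `Primitive.bind` expects. A minimal
eval-jaxpr-style loop showing the pattern (illustrative sketch, not the
patch's code):

    import jax
    import jax.numpy as jnp
    from jax import core

    def eval_jaxpr_sketch(closed, *args):
      env = dict(zip(closed.jaxpr.invars, args))
      env.update(zip(closed.jaxpr.constvars, closed.consts))
      read = lambda a: a.val if isinstance(a, core.Literal) else env[a]
      for eqn in closed.jaxpr.eqns:
        subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
        out = eqn.primitive.bind(*subfuns, *map(read, eqn.invars),
                                 **bind_params)
        out = out if eqn.primitive.multiple_results else [out]
        env.update(zip(eqn.outvars, out))
      return [env[v] for v in closed.jaxpr.outvars]

    closed = jax.make_jaxpr(lambda x: jnp.sin(x) * 2)(1.0)
    print(eval_jaxpr_sketch(closed, 1.0))  # ~[1.6829]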
@@ -668,6 +767,29 @@ def _maybe_dynamic_slice(start_idx, block_shape, value, is_indexing):
                            dtype=np.bool_)])
   return lax.squeeze(output, squeeze_dims)
 
+def _pad_to_block_dimension(value, block_shape):
+  """Pads a value so that its shape evenly divides into block dimensions.
+
+  For example, if value has shape (33, 2, 5) and block_shape is (32, 2, 4),
+  this function will pad the value to shape (64, 2, 8).
+
+  Args:
+    value: Array to be padded.
+    block_shape: Block shapes to use for padding. If None, no padding will
+      be performed.
+
+  Returns:
+    A padded array.
+  """
+  padded_shape = tuple(
+      ((v - 1) // b + 1) * b for v, b in zip(value.shape, block_shape)
+  )
+  if padded_shape != value.shape:
+    pad_width = tuple((0, a - b) for a, b in zip(padded_shape, value.shape))
+    pad_value = primitives.uninitialized_value(shape=(), dtype=value.dtype)
+    value = jnp.pad(value, pad_width, constant_values=pad_value)
+  return value
+
 def get_interpret_effects():
   return {callback._OrderedIOEffect}
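[Review note] The padded shape uses `((v - 1) // b + 1) * b`, i.e. ceiling
division followed by multiplication, rounding each dimension up to the next
multiple of its block size. Worked example matching the docstring:

    shape, block_shape = (33, 2, 5), (32, 2, 4)
    padded = tuple(((v - 1) // b + 1) * b for v, b in zip(shape, block_shape))
    assert padded == (64, 2, 8)   # 33 -> 64, 2 -> 2, 5 -> 8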
@@ -692,66 +814,93 @@ def interpret_pallas_call(
   # TODO(jburnim): Support dynamic grid sizes?
   grid = grid_mapping.static_grid
 
-  device_coords = tuple(
-      lax.axis_index(s) for s in jax_core.get_axis_env().axis_sizes)
+  axis_sizes = jax_core.get_axis_env().axis_sizes
+  device_id = _device_coords_to_logical_id(
+      tuple(lax.axis_index(s) for s in axis_sizes.keys()),
+      axis_sizes)
+
+  # Pad input arguments.
+  is_indexing_dim = [
+      tuple(b is pallas_core.mapped for b in bm.block_shape)
+      for bm in grid_mapping.block_mappings
+  ]
+  block_shapes = [
+      tuple(1 if i else b for i, b in zip(iid, bm.block_shape))
+      for iid, bm in zip(is_indexing_dim, grid_mapping.block_mappings)
+  ]
+  num_inputs = grid_mapping.num_inputs
+  input_args = [
+      _pad_to_block_dimension(a, bs)
+      for a, bs in zip(input_args, block_shapes[:num_inputs])
+  ]
 
   # Allocate buffers in HBM for outputs.
-  io_alias_map = dict(input_output_aliases)
   output_buffer_ids = []
+  output_buffer_shapes = []
   output_vals = _initialize_output_vals(
       grid_mapping.block_mappings_output, args, input_output_aliases)
-  for out_val in output_vals:
+  num_outputs = grid_mapping.num_outputs
+  output_block_shapes = block_shapes[num_inputs : num_inputs + num_outputs]
+  for out_val, bs in zip(output_vals, output_block_shapes):
+    padded_val = _pad_to_block_dimension(out_val, bs)
+    output_buffer_shapes.append(padded_val.shape)
     output_buffer_ids.append(callback.io_callback(
         _allocate_buffer,
         jax.ShapeDtypeStruct((), jnp.int16),
-        device_coords,
+        device_id,
         TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY],
-        out_val,
+        padded_val,
         ordered=True))
 
   # Allocate buffers for all kernel arguments (e.g., scalars, inputs,
   # outputs, scratch).
+  io_alias_map = dict(input_output_aliases)
+  oi_alias_map = {v: k for k, v in input_output_aliases}
   kernel_buffer_ids = []
   for var, val in zip(jaxpr.invars[grid_mapping.slice_index_ops], scalars):
     kernel_buffer_ids.append(callback.io_callback(
         _allocate_buffer,
         jax.ShapeDtypeStruct((), jnp.int16),
-        device_coords,
+        device_id,
         TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.SMEM],
         val,
         ordered=True))
 
   for i, var in enumerate(jaxpr.invars[grid_mapping.num_index_operands:]):
     output_idx = i - grid_mapping.num_inputs
+    is_input = i < grid_mapping.num_inputs
     is_output = (output_idx >= 0) and (output_idx < grid_mapping.num_outputs)
     if var.aval.memory_space == mosaic_core.TPUMemorySpace.SEMAPHORE:
       kernel_buffer_ids.append(callback.io_callback(
           _allocate_semaphores,
           jax.ShapeDtypeStruct(var.aval.shape, jnp.int16),
-          device_coords,
+          device_id,
           var.aval.shape,
           ordered=True))
     elif is_output and _is_any(var.aval.memory_space):
-      # Don't allocate a buffer -- use the already-allocated HBM output buffer.
+      # Use the already-allocated HBM output buffer.
       #
       # TODO(jburnim): For kernel args in HBM, check that block shape is the
       # same as for the corresponding pallas_call input, and that the index_map
       # is trivial.
       kernel_buffer_ids.append(output_buffer_ids[output_idx])
-    elif (i < grid_mapping.num_inputs) and (i in io_alias_map):
-      # Instead of allocating a new buffer, use the already-allocated
-      # HBM output buffer.
-      assert _is_any(var.aval.memory_space)
+    elif is_output and (output_idx in oi_alias_map):
+      # Use the already-allocated (non-HBM) input buffer.
+      kernel_buffer_ids.append(kernel_buffer_ids[oi_alias_map[output_idx]])
+    elif is_input and (i in io_alias_map) and _is_any(var.aval.memory_space):
+      # Use the already-allocated HBM output buffer.
      kernel_buffer_ids.append(output_buffer_ids[io_alias_map[i]])
     else:
+      # TODO(jburnim): For kernel args in HBM, check that block shape is the
+      # same as for the corresponding pallas_call input, and that the index_map
+      # is trivial.
       kernel_buffer_ids.append(callback.io_callback(
           _allocate_buffer,
           jax.ShapeDtypeStruct((), jnp.int16),
-          device_coords,
+          device_id,
           TPU_MEMORY_SPACE_IDXS[var.aval.memory_space],
           primitives.uninitialized_value(var.aval.shape, var.aval.dtype),
          ordered=True))
 
-  num_inputs = grid_mapping.num_inputs
   _, input_ids, kernel_output_ids, _ = split_list(
       kernel_buffer_ids,
       [grid_mapping.num_index_operands, num_inputs, grid_mapping.num_outputs])
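[Review note] The aliasing fix above consults `input_output_aliases` in both
directions: an ANY-memory-space input aliased to an output reuses the HBM
output buffer, while an aliased output in another memory space reuses its
input's kernel buffer. A small sketch of the two maps (pairs shown are
illustrative):

    input_output_aliases = ((0, 1), (2, 0))  # (input_index, output_index)

    io_alias_map = dict(input_output_aliases)               # input -> output
    oi_alias_map = {v: k for k, v in input_output_aliases}  # output -> input

    assert io_alias_map[2] == 0   # Input 2 shares output 0's buffer.
    assert oi_alias_map[1] == 0   # Output 1 shares input 0's buffer.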
@@ -769,27 +918,19 @@ def interpret_pallas_call(
       callback.io_callback(
           store,
           (),
-          device_coords,
+          device_id,
           TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY],
           buffer_id,
           (),
           val,
           ordered=True)
 
-  is_indexing_dim = [
-      tuple(b is pallas_core.mapped for b in bm.block_shape)
-      for bm in grid_mapping.block_mappings
-  ]
-  block_shapes = [
-      tuple(1 if i else b for i, b in zip(iid, bm.block_shape))
-      for iid, bm in zip(is_indexing_dim, grid_mapping.block_mappings)
-  ]
   scalar_ids, in_out_ids, scratch_ids = split_list(
       kernel_buffer_ids,
       [grid_mapping.num_index_operands, len(grid_mapping.block_mappings)])
 
   if grid:
-    num_iterations = reduce(jnp.multiply, grid)  # type: ignore[arg-type]
+    num_iterations = functools.reduce(jnp.multiply, grid)  # type: ignore[arg-type]
   else:
     # Base case is always one iteration when grid is ()
     num_iterations = 1
@@ -824,7 +965,7 @@ def body(carry):
         callback.io_callback(
             store,
             (),
-            device_coords,
+            device_id,
             TPU_MEMORY_SPACE_IDXS[var.aval.memory_space],
             input_ids[j],
             (),
@@ -845,21 +986,22 @@ def body(carry):
         kernel_output_val = callback.io_callback(
             get,
             var.aval,
-            device_coords,
+            device_id,
            TPU_MEMORY_SPACE_IDXS[var.aval.memory_space],
             kernel_output_ids[j],
             (),
             ordered=True)
         transform = indexing.NDIndexer(
-            indices=tuple(indexing.ds(st, sz)
-                          for st, sz in zip(start_indices[num_inputs + j],
-                                            block_shapes[num_inputs + j])),
+            indices=tuple(indexing.ds(st, sz) if not iid else st
+                          for st, sz, iid in zip(start_indices[num_inputs + j],
+                                                 block_shapes[num_inputs + j],
+                                                 is_indexing_dim[num_inputs + j])),
             shape=output_vals[j].shape,
             int_indexer_shape=())
         callback.io_callback(
             store,
             (),
-            device_coords,
+            device_id,
             TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY],
             output_buffer_ids[j],
             (transform,),
@@ -880,19 +1022,22 @@ def body(carry):
       callback.io_callback(
           get,
           val,
-          device_coords,
+          device_id,
           TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY],
           output_buffer_id,
-          (),
+          (indexing.NDIndexer.from_indices_shape(
+              tuple(indexing.ds(0, s) for s in val.shape),
+              output_buffer_shape),),
           ordered=True)
-      for val, output_buffer_id in zip(output_vals, output_buffer_ids)
+      for val, output_buffer_id, output_buffer_shape in zip(
+          output_vals, output_buffer_ids, output_buffer_shapes)
   ]
 
   for buffer_id in output_buffer_ids:
     callback.io_callback(
         _deallocate_buffer,
         (),
-        device_coords,
+        device_id,
         TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY],
         buffer_id,
         ordered=True)
@@ -903,7 +1048,7 @@ def body(carry):
     callback.io_callback(
         _deallocate_buffer,
         (),
-        device_coords,
+        device_id,
         TPU_MEMORY_SPACE_IDXS[var.aval.memory_space],
         buffer_id,
         ordered=True)
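[Review note] For context, a kernel that exercises several of these fixes
(block padding for a non-multiple shape, plus the interpret-mode buffer
plumbing). `TPUInterpretParams` is the entry point this module defines; the
exact import path and invocation shown here are an assumption, not part of
the patch:

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl
    from jax._src.pallas.mosaic import interpret as mosaic_interpret

    def add_one_kernel(x_ref, o_ref):
      o_ref[...] = x_ref[...] + 1.0

    x = jnp.arange(33.0)  # 33 is not a multiple of the block size of 32.
    y = pl.pallas_call(
        add_one_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(2,),
        in_specs=[pl.BlockSpec((32,), lambda i: (i,))],
        out_specs=pl.BlockSpec((32,), lambda i: (i,)),
        interpret=mosaic_interpret.TPUInterpretParams(),
    )(x)
    assert jnp.allclose(y, x + 1.0)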