Skip to content

Commit

Permalink
Replace jax.random.KeyArray with jax.Array to suppress deprecation wa…
Browse files Browse the repository at this point in the history
…rnings. (#166)
  • Loading branch information
markblee committed Nov 4, 2023
1 parent 32d439e commit b2ccd7b
Show file tree
Hide file tree
Showing 27 changed files with 81 additions and 88 deletions.
2 changes: 1 addition & 1 deletion axlearn/common/adapter_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def _create_dummy_inputs(self):
return cfg.create_dummy_input_fn(**cfg.create_dummy_input_kwargs)

def initialize_parameters_recursively(
self, prng_key: jax.random.KeyArray, *, prebuilt: Optional[NestedTensor] = None
self, prng_key: utils.Tensor, *, prebuilt: Optional[NestedTensor] = None
) -> NestedTensor:
if self._use_prebuilt_params(prebuilt):
return prebuilt
Expand Down
8 changes: 4 additions & 4 deletions axlearn/common/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,7 +875,7 @@ def transform_factorization_spec(
)

def initialize_parameters_recursively(
self, prng_key: jax.random.KeyArray, *, prebuilt: Optional[NestedTensor] = None
self, prng_key: Tensor, *, prebuilt: Optional[NestedTensor] = None
) -> NestedTensor:
if self._use_prebuilt_params(prebuilt):
return prebuilt
Expand Down Expand Up @@ -2735,7 +2735,7 @@ def __init__(self, cfg: Config, *, parent: Optional[Module]):
self._layers.append(self._add_child(f"layer{i}", layer_cfg))

def initialize_parameters_recursively(
self, prng_key: jax.random.KeyArray, *, prebuilt: Optional[NestedTensor] = None
self, prng_key: Tensor, *, prebuilt: Optional[NestedTensor] = None
) -> NestedTensor:
cfg = self.config # type: StackedTransformerLayer.Config
prng_key = split_prng_key(prng_key, cfg.num_layers)
Expand Down Expand Up @@ -3057,7 +3057,7 @@ def __init__(self, cfg: Config, *, parent: Optional[Module]):
self._add_child("repeat", repeat_cfg)

def initialize_parameters_recursively(
self, prng_key: jax.random.KeyArray, *, prebuilt: Optional[NestedTensor] = None
self, prng_key: Tensor, *, prebuilt: Optional[NestedTensor] = None
) -> NestedTensor:
# We need to call self.repeat.initialize_parameters_recursively() with the same prng_key
# to ensure initialization parity with StackedTransformerLayer.
Expand Down Expand Up @@ -3188,7 +3188,7 @@ def __init__(self, cfg: Config, *, parent: Optional[Module]):
self._add_child("pipeline", pipeline_cfg)

def initialize_parameters_recursively(
self, prng_key: jax.random.KeyArray, *, prebuilt: Optional[NestedTensor] = None
self, prng_key: Tensor, *, prebuilt: Optional[NestedTensor] = None
) -> NestedTensor:
cfg = self.config # type: PipelinedTransformerLayer.Config
# We pre-split all num_layers keys to ensure initialization parity with
Expand Down
8 changes: 4 additions & 4 deletions axlearn/common/base_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class RematSpec:
class ParameterNoise(Configurable):
"""An interface for applying parameter noise."""

def apply(self, prng_key: jax.random.KeyArray, params: NestedTensor) -> NestedTensor:
def apply(self, prng_key: Tensor, params: NestedTensor) -> NestedTensor:
"""To be implemented by subclasses."""
raise NotImplementedError(self)

Expand Down Expand Up @@ -275,7 +275,7 @@ def create_parameter_specs_recursively(self) -> NestedParameterSpec:
return specs

def initialize_parameters_recursively(
self, prng_key: jax.random.KeyArray, *, prebuilt: Optional[NestedTensor] = None
self, prng_key: Tensor, *, prebuilt: Optional[NestedTensor] = None
) -> NestedTensor:
params = {}
param_specs = self._create_layer_parameter_specs()
Expand Down Expand Up @@ -318,7 +318,7 @@ def _use_prebuilt_params(self, prebuilt: Optional[NestedTensor]) -> bool:
return True

def _initialize_parameter(
self, name: str, *, prng_key: jax.random.KeyArray, parameter_spec: ParameterSpec
self, name: str, *, prng_key: Tensor, parameter_spec: ParameterSpec
) -> Tensor:
"""Adds a parameter with the given name and shape.
Expand All @@ -345,7 +345,7 @@ def _initialize_parameter(
return param

def apply_parameter_noise_recursively(
self, prng_key: jax.random.KeyArray, params: NestedTensor
self, prng_key: Tensor, params: NestedTensor
) -> NestedTensor:
"""Applies parameter noise recursively on `params`.
Expand Down
2 changes: 1 addition & 1 deletion axlearn/common/base_layer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ class ParameterScaler(ParameterNoise):
class Config(ParameterNoise.Config):
scale: float = 1.0

def apply(self, prng_key: jax.random.KeyArray, params: NestedTensor) -> NestedTensor:
def apply(self, prng_key: utils.Tensor, params: NestedTensor) -> NestedTensor:
cfg = self.config
return jax.tree_util.tree_map(lambda x: x * cfg.scale, params)

Expand Down
3 changes: 1 addition & 2 deletions axlearn/common/conformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

from typing import Optional, Tuple, Union

import jax
from jax import numpy as jnp

from axlearn.common.attention import (
Expand Down Expand Up @@ -328,7 +327,7 @@ def __init__(self, cfg: Config, *, parent: Optional[Module]):

def initialize_parameters_recursively(
self,
prng_key: jax.random.KeyArray,
prng_key: Tensor,
*,
prebuilt: Optional[NestedTensor] = None,
) -> NestedTensor:
Expand Down
6 changes: 3 additions & 3 deletions axlearn/common/decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,7 +729,7 @@ class DecodingState(NamedTuple):
# The current state of the autoregressive decoding caches.
cache: NestedTensor
# Random generator state.
prng_key: jax.random.KeyArray
prng_key: Tensor


def _decode_init(
Expand All @@ -739,7 +739,7 @@ def _decode_init(
num_decodes: int,
max_decode_len: int,
cache: NestedTensor,
prng_key: jax.random.KeyArray,
prng_key: Tensor,
pad_id: int,
token_scores: Optional[Tensor] = None,
) -> DecodingState:
Expand Down Expand Up @@ -902,7 +902,7 @@ def sample_decode(
tokens_to_scores: Callable[[Tensor, NestedTensor], Tuple[Tensor, NestedTensor]],
stop_decoding_condition: StopDecodingCondition,
num_decodes: int,
prng_key: jax.random.KeyArray,
prng_key: Tensor,
max_decode_len: Optional[int] = None,
loop: Literal["lax", "python"] = "lax",
pad_id: int = 0,
Expand Down
26 changes: 10 additions & 16 deletions axlearn/common/evaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,7 @@ def __init__(
if mesh.empty:
raise RuntimeError("MetricCalculator should be created within the context of a mesh")

def init_state(
self, *, prng_key: jax.random.KeyArray, model_params: NestedTensor
) -> NestedTensor:
def init_state(self, *, prng_key: Tensor, model_params: NestedTensor) -> NestedTensor:
"""Initializes the state.
Will be called at the beginning of an evaluation step.
Expand Down Expand Up @@ -212,7 +210,7 @@ def _call_model(
self,
*,
method: str,
prng_key: jax.random.KeyArray,
prng_key: Tensor,
model_params: NestedTensor,
input_batch: NestedTensor,
**kwargs,
Expand Down Expand Up @@ -285,9 +283,7 @@ def __init__(
self._metric_accumulator = None
self._jit_forward = self._pjit(self._forward_in_pjit)

def init_state(
self, *, prng_key: jax.random.KeyArray, model_params: NestedTensor
) -> NestedTensor:
def init_state(self, *, prng_key: Tensor, model_params: NestedTensor) -> NestedTensor:
cfg = self.config
self._metric_accumulator = cfg.metric_accumulator.instantiate()
return dict(prng_key=prng_key)
Expand All @@ -308,7 +304,7 @@ def forward(
def _forward_in_pjit(
self,
model_params: NestedTensor,
prng_key: jax.random.KeyArray,
prng_key: Tensor,
input_batch: NestedTensor,
) -> Dict[str, NestedTensor]:
"""Calls `self._model` and returns summaries."""
Expand Down Expand Up @@ -379,9 +375,7 @@ def __init__(
model_param_partition_specs=model_param_partition_specs,
)

def init_state(
self, *, prng_key: jax.random.KeyArray, model_params: NestedTensor
) -> NestedTensor:
def init_state(self, *, prng_key: Tensor, model_params: NestedTensor) -> NestedTensor:
states = {}
for name, calculator in self.children.items():
states[name] = calculator.init_state(prng_key=prng_key, model_params=model_params)
Expand Down Expand Up @@ -525,11 +519,11 @@ def eval_step(
self,
step: int,
*,
prng_key: jax.random.KeyArray,
prng_key: Tensor,
model_params: NestedTensor,
return_aux: bool = False,
train_summaries: Optional[NestedTensor] = None,
) -> Tuple[jax.random.KeyArray, Optional[Dict[str, Any]], Optional[List[NestedTensor]]]:
) -> Tuple[Tensor, Optional[Dict[str, Any]], Optional[List[NestedTensor]]]:
"""Runs eval for the given step.
Args:
Expand Down Expand Up @@ -682,7 +676,7 @@ def __init__(
self._metric_accumulator: MetricAccumulator = None

def init_state( # pylint: disable=duplicate-code
self, *, prng_key: jax.random.KeyArray, model_params: NestedTensor
self, *, prng_key: Tensor, model_params: NestedTensor
) -> NestedTensor:
self._metric_accumulator = MetricAccumulator.default_config().instantiate()
return dict(prng_key=prng_key)
Expand Down Expand Up @@ -724,7 +718,7 @@ def forward(
def _predict_in_pjit(
self,
model_params: NestedTensor,
prng_key: jax.random.KeyArray,
prng_key: Tensor,
input_batch: NestedTensor,
) -> Dict[str, NestedTensor]:
"""Core function that calls model's predict() method for each batch and will be pjit-ed."""
Expand Down Expand Up @@ -759,7 +753,7 @@ def _calculate_metrics(self, outputs: PredictionOutputs) -> Dict[str, Tensor]:
def _compute_metrics_in_pjit(
self,
model_params: NestedTensor,
prng_key: jax.random.KeyArray,
prng_key: Tensor,
outputs: List[PredictionOutputs],
) -> Dict[str, NestedTensor]:
"""Computes metrics and returns them in "replicated"."""
Expand Down
2 changes: 1 addition & 1 deletion axlearn/common/evaler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def _call_model(
self,
*,
method: str,
prng_key: jax.random.KeyArray,
prng_key: Tensor,
model_params: NestedTensor,
input_batch: NestedTensor,
**kwargs,
Expand Down
17 changes: 9 additions & 8 deletions axlearn/common/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
NestedPartitionSpec,
NestedTensor,
PartitionSpec,
Tensor,
TensorSpec,
)

Expand All @@ -56,12 +57,12 @@ class MethodRunner:
def __init__(
self,
*,
prng_key: jax.random.KeyArray,
prng_key: Tensor,
mesh: jax.sharding.Mesh,
input_batch_partition_spec: DataPartitionType,
jit_run_on_batch: Callable[
[jax.random.KeyArray, NestedTensor],
Tuple[jax.random.KeyArray, NestedTensor, NestedTensor],
[Tensor, NestedTensor],
Tuple[Tensor, NestedTensor, NestedTensor],
],
):
"""Initializes MethodRunner object.
Expand Down Expand Up @@ -141,7 +142,7 @@ def __call__(self, input_batch: NestedTensor) -> Output:
class _InferenceRunnerState(NamedTuple):
"""Contains inference runner {state | state-partition-specs}."""

prng_key: Union[jax.random.KeyArray, NestedPartitionSpec]
prng_key: Union[Tensor, NestedPartitionSpec]
model: Union[NestedTensor, NestedPartitionSpec]
learner: Optional[Union[NestedTensor, NestedPartitionSpec]] = None

Expand Down Expand Up @@ -255,7 +256,7 @@ def run(
input_batches: Iterable[NestedTensor],
*,
method: str,
prng_key: Optional[jax.random.KeyArray] = None,
prng_key: Optional[Tensor] = None,
**kwargs,
) -> Generator[NestedTensor, None, None]:
"""Runs inference on the provided input batches.
Expand Down Expand Up @@ -296,7 +297,7 @@ def create_method_runner(
self,
*,
method: str,
prng_key: Optional[jax.random.KeyArray] = None,
prng_key: Optional[Tensor] = None,
**kwargs,
) -> MethodRunner:
"""Creates MethodRunner for the specified method and arguments.
Expand Down Expand Up @@ -361,13 +362,13 @@ def inference_iter(model_params, prng_key, input_batch):

def _inference_iter(
self,
prng_key: jax.random.KeyArray,
prng_key: Tensor,
model_params: NestedTensor,
input_batch: Dict[str, Any],
*,
method,
**kwargs,
) -> Tuple[jax.random.KeyArray, NestedTensor, NestedTensor]:
) -> Tuple[Tensor, NestedTensor, NestedTensor]:
"""Implements inference for a single input batch."""
cfg = self.config
new_prng_key, iter_key = jax.random.split(prng_key)
Expand Down
4 changes: 2 additions & 2 deletions axlearn/common/inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def initialize(
self,
name: str,
*,
prng_key: jax.random.KeyArray,
prng_key: Tensor,
shape: Shape,
dtype: jnp.dtype,
axes: Optional[FanAxes] = None,
Expand Down Expand Up @@ -238,7 +238,7 @@ def _build_ckpt(
root_dir: str,
mesh_shape: Tuple[int, int],
mesh_axis_names: Tuple[str, str],
prng_key: jax.random.KeyArray,
prng_key: Tensor,
use_ema: bool = False,
) -> Tuple[NestedTensor, str]:
devices = mesh_utils.create_device_mesh(mesh_shape)
Expand Down
2 changes: 1 addition & 1 deletion axlearn/common/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1777,7 +1777,7 @@ class Config(ParameterNoise.Config):

vn_std: Required[float] = REQUIRED

def apply(self, prng_key: jax.random.KeyArray, params: NestedTensor) -> NestedTensor:
def apply(self, prng_key: Tensor, params: NestedTensor) -> NestedTensor:
cfg = self.config
if cfg.vn_std <= 0:
return params
Expand Down
2 changes: 1 addition & 1 deletion axlearn/common/metrics_glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class Config(ModelSummaryAccumulator.Config):
def _forward_in_pjit(
self,
model_params: NestedTensor,
prng_key: jax.random.KeyArray,
prng_key: Tensor,
input_batch: NestedTensor,
) -> Dict[str, NestedTensor]:
"""Calls `self._model` and returns summaries."""
Expand Down
6 changes: 3 additions & 3 deletions axlearn/common/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ class InvocationContext: # pylint: disable=too-many-instance-attributes
# The state of the module.
state: NestedTensor
is_training: bool
prng_key: Optional[jax.random.KeyArray]
prng_key: Optional[Tensor]
output_collection: OutputCollection

def path(self):
Expand Down Expand Up @@ -670,7 +670,7 @@ def is_training(self) -> bool:
return self.get_invocation_context().is_training

@property
def prng_key(self) -> jax.random.KeyArray:
def prng_key(self) -> Tensor:
return self.get_invocation_context().prng_key

@property
Expand Down Expand Up @@ -724,7 +724,7 @@ def nullary():

def functional(
module: Module,
prng_key: Optional[jax.random.KeyArray],
prng_key: Optional[Tensor],
state: NestedTensor,
inputs: Union[Sequence[Any], Dict[str, Any]],
*,
Expand Down
Loading

0 comments on commit b2ccd7b

Please # to comment.