Skip to content

Commit

Permalink
Support optimization_level and memory_fitting_level XLA compilation o…
Browse files Browse the repository at this point in the history
…ptions.

PiperOrigin-RevId: 727070422
  • Loading branch information
Google-ML-Automation committed Feb 14, 2025
1 parent 531d80d commit d3850e7
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 1 deletion.
15 changes: 14 additions & 1 deletion jax/_src/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from jax._src import traceback_util
from jax._src.interpreters import mlir
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
import numpy as np

Expand Down Expand Up @@ -190,6 +191,13 @@ def get_compile_options(

build_options.exec_time_optimization_effort = config.exec_time_optimization_effort.value
build_options.memory_fitting_effort = config.memory_fitting_effort.value
if xla_extension_version >= 316:
build_options.optimization_level = config.EffortLevel(
config.optimization_level.value
).value
build_options.memory_fitting_level = config.EffortLevel(
config.memory_fitting_level.value
).value

# This is a temporary workaround to simplify the AutoPGLE usage.
# TODO(b/376647494): Remove once the bug is fixed.
Expand All @@ -203,7 +211,12 @@ def get_compile_options(
if env_options_overrides is not None:
# Some overrides are passed directly on build_options.
overrides_on_build_options = [
'exec_time_optimization_effort', 'memory_fitting_effort']
"exec_time_optimization_effort", "memory_fitting_effort"]
if xla_extension_version >= 316:
overrides_on_build_options.extend(
["optimization_level", "memory_fitting_level"]
)

env_options_overrides = dict(env_options_overrides)
for name in overrides_on_build_options:
if name in env_options_overrides:
Expand Down
55 changes: 55 additions & 0 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from collections.abc import Callable, Iterator, Sequence
import contextlib
import enum
import functools
import itertools
import logging
Expand All @@ -35,6 +36,29 @@
_T = TypeVar('_T')


class EffortLevel(enum.Enum):
"""Effort level enum, mirroring the XLA effort options."""

UNKNOWN = 0
O0 = 9
O1 = 19
O2 = 29
O3 = 39

@classmethod
def _missing_(cls, value: object) -> EffortLevel | None:
return _effort_from_string.get(value)


_effort_from_string: dict[Any, EffortLevel] = {
'UNKNOWN': EffortLevel.UNKNOWN,
'O0': EffortLevel.O0,
'O1': EffortLevel.O1,
'O2': EffortLevel.O2,
'O3': EffortLevel.O3,
}


def bool_env(varname: str, default: bool) -> bool:
"""Read an environment variable and interpret it as a boolean.
Expand Down Expand Up @@ -1727,6 +1751,37 @@ def _update_garbage_collection_guard(state, key, val):
help='Effort for minimizing memory usage (higher means more effort), valid range [-1.0, 1.0].'
)

optimization_level = enum_state(
name='jax_optimization_level',
enum_values=[
'UNKNOWN',
'O0',
'O1',
'O2',
'O3',
],
default='UNKNOWN',
help='The degree to which the compiler should optimize for execution time',
include_in_jit_key=True
)

memory_fitting_level = enum_state(
name='jax_memory_fitting_level',
enum_values=[
'UNKNOWN',
'O0',
'O1',
'O2',
'O3',
],
default='UNKNOWN',
help=(
'The degree to which the compiler should attempt to make the program'
' fit in memory'
),
include_in_jit_key=True
)

cpu_collectives_implementation = optional_enum_state(
name='jax_cpu_collectives_implementation',
enum_values=["gloo", "mpi", "megascale"],
Expand Down
31 changes: 31 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from jax._src.interpreters import partial_eval as pe
from jax._src.compilation_cache import is_persistent_cache_enabled
from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
import jax._src.util as jax_util
from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
import jax.custom_batching
Expand Down Expand Up @@ -1366,6 +1367,36 @@ def f(x):
"exec_time_compilation_effort": 0.0,
})(1.0)

def test_optimization_level_compiler_option(self):
def f(x):
return jnp.sqrt(x**2) + 1.0

if xla_extension_version < 316:
self.skipTest("Requires XLA extension version >= 316")
f_jit = jit(
f,
compiler_options={
"optimization_level": config.EffortLevel.O1.value,
},
)(
1.0
) # doesn't crash.

def test_memory_fitting_level_compiler_option(self):
def f(x):
return jnp.sqrt(x**2) + 1.0

if xla_extension_version < 316:
self.skipTest("Requires XLA extension version >= 316")
f_jit = jit(
f,
compiler_options={
"memory_fitting_level": config.EffortLevel.O0.value,
},
)(
1.0
) # doesn't crash.

def test_jit_lower_compile_with_compiler_options_invalid(self):
def f(x):
return jnp.sqrt(x ** 2) + 1.
Expand Down

0 comments on commit d3850e7

Please # to comment.