From 427ded033106c7f12427f03dbc82cfaacd63b7f5 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Sat, 3 Feb 2024 17:10:08 +0000 Subject: [PATCH] Added or fixed example READMEs --- quax/examples/lora/README.md | 3 +++ quax/examples/named/README.md | 4 ++-- quax/examples/prng/README.md | 7 ++++--- quax/examples/zero/README.md | 31 +++++++++++++++++++++++++++++++ 4 files changed, 40 insertions(+), 5 deletions(-) create mode 100644 quax/examples/lora/README.md create mode 100644 quax/examples/zero/README.md diff --git a/quax/examples/lora/README.md b/quax/examples/lora/README.md new file mode 100644 index 0000000..831604d --- /dev/null +++ b/quax/examples/lora/README.md @@ -0,0 +1,3 @@ +# Low-rank adaptation + +See https://docs.kidger.site/quax/api/lora/ for the official documentation for this library. (Which is the only "officially supported" example amongst all of those listed in this directory.) diff --git a/quax/examples/named/README.md b/quax/examples/named/README.md index faec91b..efb1844 100644 --- a/quax/examples/named/README.md +++ b/quax/examples/named/README.md @@ -7,8 +7,8 @@ These are arrays with named dimensions. We can then use these to specify which d ```python import equinox as eqx import jax.random as jr -import named import quax +import quax.examples.named as named # Existing program linear = eqx.nn.Linear(3, 4, key=jr.PRNGKey(0)) @@ -24,7 +24,7 @@ vector = named.NamedArray(jr.normal(jr.PRNGKey(1), (3,)), (In,)) # Wrap function (here using matrix-vector multiplication) with quaxify. Output will be # a NamedArray! out = quax.quaxify(named_linear)(vector) -print(out) # NamedArray(array=f32[4], axes=(Axis(name='Out', size=4),)) +print(out) # NamedArray(array=f32[4], axes=(Axis(size=4),)) ``` ## API diff --git a/quax/examples/prng/README.md b/quax/examples/prng/README.md index e0a82fb..3010ff1 100644 --- a/quax/examples/prng/README.md +++ b/quax/examples/prng/README.md @@ -7,16 +7,17 @@ JAX has a built in `jax.random.key` for creating PRNG keys. Here we demonstrate ```python import jax.lax as lax import jax.numpy as jnp -import prng import quax +import quax.examples.prng as prng # `key` is a PyTree wrapping a u32[2] array. key = prng.ThreeFry(0) prng.normal(key) # Some primitives (lax.add_p) are disallowed. -key + 1 # TypeError! -quax.quaxify(lax.add)(key, 1) # Still a TypeError! +def f(x, y): + return x + y +quax.quaxify(f)(key, 1) # TypeError! # Some primitives (lax.select_n) are allowed. # We're calling `jnp.where(..., pytree1, pytree2)` -- on pytrees, not arrays! diff --git a/quax/examples/zero/README.md b/quax/examples/zero/README.md new file mode 100644 index 0000000..16af51a --- /dev/null +++ b/quax/examples/zero/README.md @@ -0,0 +1,31 @@ +# Symbolic zeros + +This example library allows for the creation of symbolic zeros. These are equivalent to an array of zeros, like those created by `z = jnp.zeros(shape, dtype)`, except that we are able to resolve many operations, like `z * 5`, or `jnp.concatenate([z, z])`, at trace time -- as in these cases the result is again known to be a symbolic zero! -- and so we do not need to wait until runtime or hope that the compiler will figure it out. + +As such, these are a more powerful version of the symbolic zeros JAX already uses inside its autodifferentiation rules (to skip computing gradients where none are required). However in this example library, we can use them anywhere in JAX. + +## API + +```python +zero.Zero +``` + +## Example + +In this example, `quax.examples.zero` correctly identifies that (a) slicing an array of zeros again produces an array of zeros, and (b) that multiplying zero against nonzero still returns zero. + +Thus the return value is again a symbolic zero. + +```python +import jax.numpy as jnp +import quax +import quax.examples.zero as zero + +z = zero.Zero((3, 4), jnp.float32) # shape and dtype + +def slice_and_multiply(a, b): + return a[:, :2] * b + +out = quax.quaxify(slice_and_multiply)(z, 3) +print(out) # Zero(shape=(3, 2), dtype=dtype('float32')) +```