Skip to content

Commit

Permalink
Added or fixed example READMEs
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Feb 3, 2024
1 parent 0714f74 commit 427ded0
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 5 deletions.
3 changes: 3 additions & 0 deletions quax/examples/lora/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Low-rank adaptation

See https://docs.kidger.site/quax/api/lora/ for the official documentation for this library. (Which is the only "officially supported" example amongst all of those listed in this directory.)
4 changes: 2 additions & 2 deletions quax/examples/named/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ These are arrays with named dimensions. We can then use these to specify which d
```python
import equinox as eqx
import jax.random as jr
import named
import quax
import quax.examples.named as named

# Existing program
linear = eqx.nn.Linear(3, 4, key=jr.PRNGKey(0))
Expand All @@ -24,7 +24,7 @@ vector = named.NamedArray(jr.normal(jr.PRNGKey(1), (3,)), (In,))
# Wrap function (here using matrix-vector multiplication) with quaxify. Output will be
# a NamedArray!
out = quax.quaxify(named_linear)(vector)
print(out) # NamedArray(array=f32[4], axes=(Axis(name='Out', size=4),))
print(out) # NamedArray(array=f32[4], axes=(Axis(size=4),))
```

## API
Expand Down
7 changes: 4 additions & 3 deletions quax/examples/prng/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,17 @@ JAX has a built in `jax.random.key` for creating PRNG keys. Here we demonstrate
```python
import jax.lax as lax
import jax.numpy as jnp
import prng
import quax
import quax.examples.prng as prng

# `key` is a PyTree wrapping a u32[2] array.
key = prng.ThreeFry(0)
prng.normal(key)

# Some primitives (lax.add_p) are disallowed.
key + 1 # TypeError!
quax.quaxify(lax.add)(key, 1) # Still a TypeError!
def f(x, y):
return x + y
quax.quaxify(f)(key, 1) # TypeError!

# Some primitives (lax.select_n) are allowed.
# We're calling `jnp.where(..., pytree1, pytree2)` -- on pytrees, not arrays!
Expand Down
31 changes: 31 additions & 0 deletions quax/examples/zero/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Symbolic zeros

This example library allows for the creation of symbolic zeros. These are equivalent to an array of zeros, like those created by `z = jnp.zeros(shape, dtype)`, except that we are able to resolve many operations, like `z * 5`, or `jnp.concatenate([z, z])`, at trace time -- as in these cases the result is again known to be a symbolic zero! -- and so we do not need to wait until runtime or hope that the compiler will figure it out.

As such, these are a more powerful version of the symbolic zeros JAX already uses inside its autodifferentiation rules (to skip computing gradients where none are required). However in this example library, we can use them anywhere in JAX.

## API

```python
zero.Zero
```

## Example

In this example, `quax.examples.zero` correctly identifies that (a) slicing an array of zeros again produces an array of zeros, and (b) that multiplying zero against nonzero still returns zero.

Thus the return value is again a symbolic zero.

```python
import jax.numpy as jnp
import quax
import quax.examples.zero as zero

z = zero.Zero((3, 4), jnp.float32) # shape and dtype

def slice_and_multiply(a, b):
return a[:, :2] * b

out = quax.quaxify(slice_and_multiply)(z, 3)
print(out) # Zero(shape=(3, 2), dtype=dtype('float32'))
```

0 comments on commit 427ded0

Please # to comment.