-
-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
0714f74
commit 427ded0
Showing
4 changed files
with
40 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# Low-rank adaptation | ||
|
||
See https://docs.kidger.site/quax/api/lora/ for the official documentation for this library. (Which is the only "officially supported" example amongst all of those listed in this directory.) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
# Symbolic zeros | ||
|
||
This example library allows for the creation of symbolic zeros. These are equivalent to an array of zeros, like those created by `z = jnp.zeros(shape, dtype)`, except that we are able to resolve many operations, like `z * 5`, or `jnp.concatenate([z, z])`, at trace time -- as in these cases the result is again known to be a symbolic zero! -- and so we do not need to wait until runtime or hope that the compiler will figure it out. | ||
|
||
As such, these are a more powerful version of the symbolic zeros JAX already uses inside its autodifferentiation rules (to skip computing gradients where none are required). However in this example library, we can use them anywhere in JAX. | ||
|
||
## API | ||
|
||
```python | ||
zero.Zero | ||
``` | ||
|
||
## Example | ||
|
||
In this example, `quax.examples.zero` correctly identifies that (a) slicing an array of zeros again produces an array of zeros, and (b) that multiplying zero against nonzero still returns zero. | ||
|
||
Thus the return value is again a symbolic zero. | ||
|
||
```python | ||
import jax.numpy as jnp | ||
import quax | ||
import quax.examples.zero as zero | ||
|
||
z = zero.Zero((3, 4), jnp.float32) # shape and dtype | ||
|
||
def slice_and_multiply(a, b): | ||
return a[:, :2] * b | ||
|
||
out = quax.quaxify(slice_and_multiply)(z, 3) | ||
print(out) # Zero(shape=(3, 2), dtype=dtype('float32')) | ||
``` |