jax_nan_playground.py
import jax
import jax.numpy as jnp
import haiku as hk


def f(x):
    # Partial mask: only the first weight is -inf, the rest stay finite.
    weight = jnp.ones_like(x).at[0].set(-jnp.array(float("inf")))
    valid_samples = jnp.isfinite(weight)
    inner_term = x + weight
    # Re-apply the -inf mask to the invalid entries before the reduction.
    inner_term = jnp.where(valid_samples, inner_term, -jnp.ones_like(inner_term) * float("inf"))
    y = jax.nn.logsumexp(inner_term, axis=0)
    return y


def g(x):
    # Full mask: every weight is -inf, so no entry survives the reduction.
    weight = -jnp.ones_like(x) * jnp.array(float("inf"))
    valid_samples = jnp.isfinite(weight)
    inner_term = x + weight
    inner_term = jnp.where(valid_samples, inner_term, -jnp.ones_like(inner_term) * float("inf"))
    y = jax.nn.logsumexp(inner_term, axis=0)
    return y


def h(params):
    # Feed the Haiku MLP output through the partially masked loss f.
    x = forward.apply(params)
    loss = f(x)
    return loss


@hk.without_apply_rng
@hk.transform
def forward():
    # A tiny MLP with randomly initialised biases; the input is fixed to ones.
    x = jnp.ones(3)
    y = hk.nets.MLP([3, 3], b_init=hk.initializers.RandomNormal())(x)
    return y
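

# Sketch (not part of the original playground): a commonly recommended way to
# keep gradients finite when masking with -inf is the "double where" pattern:
# replace the non-finite values with a harmless finite placeholder before they
# are combined with the differentiated inputs, then re-apply the -inf mask
# around the already-safe expression. `masked_logsumexp` below is a
# hypothetical helper illustrating that pattern, not code from the original
# file.
def masked_logsumexp(x, weight):
    valid = jnp.isfinite(weight)
    # Inner where: swap the masked weights for a finite placeholder before
    # they touch x, so no -inf is combined with the differentiated input.
    safe_term = x + jnp.where(valid, weight, 0.0)
    # Outer where: re-apply the -inf mask for the reduction.
    inner_term = jnp.where(valid, safe_term, -jnp.inf)
    return jax.nn.logsumexp(inner_term, axis=0)


# Usage sketch:
#   jax.value_and_grad(
#       lambda x: masked_logsumexp(x, jnp.ones_like(x).at[0].set(-jnp.inf))
#   )(jnp.ones(3))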


if __name__ == '__main__':
    # Value and gradient with a single entry masked to -inf.
    val, grad = jax.value_and_grad(f)(jnp.ones(3))
    print(val, grad)
    # Value and gradient with every entry masked to -inf.
    val, grad = jax.value_and_grad(g)(jnp.ones(3))
    print(val, grad)
    # Same experiment, differentiating with respect to the MLP parameters.
    params = forward.init(jax.random.PRNGKey(0))
    val, grad = jax.value_and_grad(h)(params)
    print(grad)
    # print(params)