run_training.py
#!/usr/bin/env python
from causalprob import CausalProb
from inference.training import train
from tools.structures import unpack
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import time

if __name__ == '__main__':
    dim = 1
    n_obs = 1000
    n_samples = 100000
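
    # Build the linear confounder model and set its ground-truth parameters.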
    from models.linear_confounder_model import define_model
    model = define_model(dim=dim)
    true_theta = {'V1->X': jnp.array([1.]), 'X->Y': jnp.array([2.]), 'V1->Y': jnp.array([3.])}
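
    # Sample the exogenous variables u and fill in the endogenous variables v
    # (V1, X and Y) under the true parameters; these serve as the observations.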
    cp = CausalProb(model=model)
    u, v = cp.fill({k: u(n_obs, true_theta) for k, u in cp.draw_u.items()}, {}, true_theta, cp.draw_u.keys())
    x = v['X']
    o = {'V1': v['V1']}
    y = v['Y']
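
    # Random initialization of the parameters to be trained.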
    theta0 = unpack(jnp.array(np.random.normal(size=3 * dim)), true_theta)
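
    # Fit the parameters by minimizing the negative average log-evidence.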
    start_time = time.time()
    print('Training model parameters...')
    theta, losses = train(model, x, y, o, theta0, loss_type='neg-avg-log-evidence')
    print('Training completed in {} seconds.'.format(np.round(time.time() - start_time, 2)))
    print('Optimal theta parameters:', theta)
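
    # Plot the training loss curve.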
    plt.plot(losses)
    plt.show()
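
    # Sample from the trained model and compare the estimated distribution of
    # each variable against the true one, one histogram pair per subplot.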
    _, est_v = cp.fill({k: u(n_samples, theta) for k, u in cp.draw_u.items()}, {}, theta, cp.draw_u.keys())
    plt.figure(figsize=(20, 4))
    for i, rv in enumerate(u):
        plt.subplot(1, 3, i + 1)
        plt.title(rv, fontsize=15)
        plt.hist(v[rv].squeeze(1), bins=int(np.sqrt(2 * n_obs)), alpha=0.5, density=True)
        plt.hist(est_v[rv].squeeze(1), bins=int(np.sqrt(2 * n_samples)), alpha=0.5, density=True)
        plt.legend(['true distribution', 'estimated distribution'], fontsize=12)
    plt.show()