mixture_of_mvns.py — forked from juho-lee/set_transformer
(executable file, 61 lines, 50 loc, 1.87 KB)
import torch
from torch.distributions import (Dirichlet, Categorical)
from plots import scatter_mog
import matplotlib.pyplot as plt
class MultivariateNormal(object):
    """Abstract interface for a family of multivariate normal components.

    Concrete subclasses implement sampling, per-component log-density,
    parameter summarization, and parsing of raw network output. All methods
    here are stubs that raise ``NotImplementedError``.
    """

    def __init__(self, dim):
        # Dimensionality of the data space each component lives in.
        self.dim = dim

    def sample(self, B, K, labels):
        """Sample data given per-point component assignments.

        Expected (from the MixtureOfMVNs call site) to return a pair
        ``(X, params)`` — TODO confirm in concrete subclasses.
        """
        raise NotImplementedError

    def log_prob(self, X, params):
        """Per-component log-density of ``X`` under ``params``."""
        raise NotImplementedError

    def stats(self, params):
        # Fixed: MixtureOfMVNs.plot calls self.mvn.stats(params), but this
        # stub previously declared no `params` argument. Expected to return
        # a (mu, cov) pair per that call site.
        raise NotImplementedError

    def parse(self, raw):
        """Parse raw (e.g. network-output) tensors into distribution params."""
        raise NotImplementedError
class MixtureOfMVNs(object):
    """A K-component mixture whose components are supplied by an ``mvn`` object.

    The ``mvn`` collaborator provides per-component sampling, log-density,
    summary statistics, and parsing; this class adds the mixture weights
    and the mixture-level likelihood.
    """

    def __init__(self, mvn):
        # Component family implementing sample/log_prob/stats/parse.
        self.mvn = mvn

    def sample(self, B, N, K, return_gt=False):
        """Draw B datasets of N points each from a K-component mixture.

        Mixture weights come from a flat Dirichlet; per-point component
        assignments from a Categorical over those weights. When
        ``return_gt`` is set, the ground-truth labels, weights, and
        component parameters are returned alongside the data.
        """
        if torch.cuda.is_available():
            device = torch.cuda.current_device()
        else:
            device = 'cpu'
        pi = Dirichlet(torch.ones(K)).sample(torch.Size([B])).to(device)
        labels = Categorical(probs=pi).sample(torch.Size([N])).to(device)
        # Sampling produced shape (N, B); callers expect (B, N).
        labels = labels.transpose(0, 1).contiguous()
        X, params = self.mvn.sample(B, K, labels)
        if return_gt:
            return X, labels, pi, params
        return X

    def log_prob(self, X, pi, params, return_labels=False):
        """Mean per-point mixture log-likelihood (and argmax labels if asked).

        Adds log mixture weights (epsilon-guarded against log(0)) to the
        per-component log-densities, then marginalizes components with
        logsumexp.
        """
        component_ll = self.mvn.log_prob(X, params)
        joint_ll = component_ll + (pi + 1e-10).log().unsqueeze(-2)
        mean_ll = joint_ll.logsumexp(-1).mean()
        if not return_labels:
            return mean_ll
        return mean_ll, joint_ll.argmax(-1)

    def plot(self, X, labels, params, axes):
        """Scatter each dataset with its component assignments on ``axes``."""
        mu, cov = self.mvn.stats(params)
        for idx, ax in enumerate(axes.flatten()):
            scatter_mog(X[idx].cpu().data.numpy(),
                        labels[idx].cpu().data.numpy(),
                        mu[idx].cpu().data.numpy(),
                        cov[idx].cpu().data.numpy(),
                        ax=ax)
            ax.set_xticks([])
            ax.set_yticks([])
        plt.subplots_adjust(hspace=0.1, wspace=0.1)

    def parse(self, raw):
        """Delegate parsing of raw parameter tensors to the component family."""
        return self.mvn.parse(raw)