distribution.py
import numpy as np
import torch

class Bernoulli:
    """Factorized Bernoulli distribution with mean parameter ``mu``."""

    def __init__(self, mu):
        # mu: [batch_size, x_dim], values in [0, 1]
        self.mu = mu

    def log_probability(self, x):
        # Clamp locally to avoid log(0); do not mutate self.mu in place.
        mu = torch.clamp(self.mu, min=1e-5, max=1.0 - 1e-5)
        return (x * torch.log(mu) + (1.0 - x) * torch.log(1.0 - mu)).sum(1)

    def sample(self):
        # torch.rand_like already matches the device and dtype of mu.
        return (torch.rand_like(self.mu) < self.mu).to(torch.float)
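
# A minimal usage sketch (illustrative, not part of the original file):
# a batch of two 3-dimensional Bernoulli means, sampled and then scored.
#
#   mu = torch.tensor([[0.1, 0.5, 0.9], [0.2, 0.2, 0.2]])
#   dist = Bernoulli(mu)
#   x = dist.sample()                 # binary tensor, shape [2, 3]
#   logp = dist.log_probability(x)    # per-example log-likelihood, shape [2]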

class DiagonalGaussian:
    """Gaussian with diagonal covariance, parameterized by mean and log-variance."""

    def __init__(self, mu, logvar):
        # mu, logvar: [batch_size, z_dim]
        self.mu = mu
        self.logvar = logvar

    def log_probability(self, x):
        return -0.5 * torch.sum(np.log(2.0 * np.pi) + self.logvar
                                + ((x - self.mu) ** 2) / torch.exp(self.logvar), dim=1)

    def sample(self):
        # Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I).
        eps = torch.randn_like(self.mu)
        return self.mu + torch.exp(0.5 * self.logvar) * eps

    def repeat(self, n):
        # Tile each batch element n times, then flatten back to [batch_size * n, z_dim].
        mu = self.mu.unsqueeze(1).repeat(1, n, 1).view(-1, self.mu.shape[-1])
        logvar = self.logvar.unsqueeze(1).repeat(1, n, 1).view(-1, self.logvar.shape[-1])
        return DiagonalGaussian(mu, logvar)

    @staticmethod
    def kl_div(p, q):
        # KL(p || q) between two diagonal Gaussians, summed over dimensions.
        return 0.5 * torch.sum(q.logvar - p.logvar - 1.0
                               + (torch.exp(p.logvar) + (p.mu - q.mu) ** 2) / torch.exp(q.logvar),
                               dim=1)
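
# Illustrative sketch (assumed usage, not in the original file): a standard
# normal prior corresponds to mu = 0, logvar = 0, and the KL divergence of a
# diagonal Gaussian with itself should be exactly zero.
#
#   mu, logvar = torch.zeros(4, 8), torch.zeros(4, 8)
#   prior = DiagonalGaussian(mu, logvar)
#   z = prior.sample()                          # shape [4, 8]
#   kl = DiagonalGaussian.kl_div(prior, prior)  # all zeros, shape [4]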

class Gaussian:
    """Gaussian with full covariance, parameterized by mean and precision matrix."""

    def __init__(self, mu, precision):
        # mu: [batch_size, z_dim]
        self.mu = mu
        # precision: [batch_size, z_dim, z_dim]
        self.precision = precision
        # L is the Cholesky factor of the covariance (inverse precision).
        # torch.cholesky is deprecated; torch.linalg.cholesky is the current API.
        # TODO: get rid of the inverse for efficiency.
        self.L = torch.linalg.cholesky(torch.inverse(precision))
        self.dim = self.mu.shape[1]

    def log_probability(self, x):
        # log det(covariance) = 2 * sum(log(diag(L))).
        indices = np.arange(self.L.shape[-1])
        return -0.5 * (self.dim * np.log(2.0 * np.pi)
                       + 2.0 * torch.log(self.L[:, indices, indices]).sum(1)
                       + torch.matmul(torch.matmul((x - self.mu).unsqueeze(1), self.precision),
                                      (x - self.mu).unsqueeze(-1)).sum([1, 2]))

    def sample(self):
        # z = mu + L @ eps with eps ~ N(0, I), so z ~ N(mu, L L^T).
        eps = torch.randn_like(self.mu)
        return self.mu + torch.matmul(self.L, eps.unsqueeze(-1)).squeeze(-1)

    def repeat(self, n):
        # Tile each batch element n times, mirroring DiagonalGaussian.repeat.
        mu = self.mu.unsqueeze(1).repeat(1, n, 1).view(-1, self.mu.shape[-1])
        precision = self.precision.unsqueeze(1).repeat(1, n, 1, 1).view(-1, *self.precision.shape[1:])
        return Gaussian(mu, precision)
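
if __name__ == "__main__":
    # Smoke test added for illustration (not in the original file). With an
    # identity precision matrix, Gaussian coincides with a standard normal,
    # so its log-density should match DiagonalGaussian with mu = 0, logvar = 0.
    torch.manual_seed(0)
    batch, dim = 4, 3
    mu = torch.zeros(batch, dim)
    precision = torch.eye(dim).expand(batch, dim, dim)
    full = Gaussian(mu, precision)
    diag = DiagonalGaussian(mu, torch.zeros(batch, dim))
    x = full.sample()
    print(torch.allclose(full.log_probability(x), diag.log_probability(x), atol=1e-5))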