-
Notifications
You must be signed in to change notification settings - Fork 18
/
utils.py
56 lines (42 loc) · 1.62 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import jax
import jax.numpy as jnp
import numpy as np
from PIL import Image
def from_pil_image(x):
    """Converts from a PIL image to a JAX array.

    The input is first materialised with ``np.asarray``, which understands the
    ``__array_interface__`` protocol that PIL images expose, and the result is
    then moved to JAX. Plain NumPy/JAX arrays are accepted as well.

    Args:
        x: a PIL image, or any array-like laid out as (H, W) or (H, W, C).

    Returns:
        An array laid out as (C, H, W). Values are mapped linearly with
        ``v / 127.5 - 1``, so 0 -> -1 and 255 -> 1.
    """
    # Go through NumPy: jnp.array may reject objects that only provide
    # __array_interface__ (as PIL images do), while np.asarray handles them.
    x = jnp.asarray(np.asarray(x))
    if x.ndim == 2:
        # Single-channel (H, W) input: add a trailing channel axis.
        x = x[..., None]
    # (H, W, C) -> (C, H, W), then rescale [0, 255] -> [-1, 1].
    return x.transpose((2, 0, 1)) / 127.5 - 1
def to_pil_image(x):
    """Converts from a JAX array to a PIL image.

    Args:
        x: an array laid out as (C, H, W), or (1, C, H, W) with a batch axis of
            size one. Values are mapped with ``(v + 1) * 127.5`` and clipped to
            [0, 255], so the expected input range is [-1, 1].

    Returns:
        A PIL image built with ``Image.fromarray``. A single-channel input is
        passed as a 2-D (H, W) array; otherwise the channel axis is moved last,
        giving an (H, W, C) array.

    Raises:
        ValueError: if ``x`` is 4-D with a batch size other than one.
    """
    if x.ndim == 4:
        # Only a singleton batch can be turned into a single image. An explicit
        # raise is used because assert statements are stripped under -O.
        if x.shape[0] != 1:
            raise ValueError(
                f"expected a batch of exactly one image, got batch size {x.shape[0]}"
            )
        x = x[0]
    if x.shape[0] == 1:
        # Drop the channel axis so Image.fromarray receives a 2-D array.
        x = x[0]
    else:
        x = x.transpose((1, 2, 0))
    # Map [-1, 1] -> [0, 255], clip, round to nearest, and convert to uint8.
    arr = np.array(jnp.round(jnp.clip((x + 1) * 127.5, 0, 255)).astype(jnp.uint8))
    return Image.fromarray(arr)
def log_snr_to_alpha_sigma(log_snr):
    """Maps a log signal-to-noise ratio to the pair (alpha, sigma).

    alpha scales the clean image and sigma scales the noise. They are chosen
    so that alpha**2 = sigmoid(log_snr) and sigma**2 = sigmoid(-log_snr);
    hence alpha**2 + sigma**2 == 1 and log(alpha**2 / sigma**2) == log_snr.
    """
    alpha_sq = jax.nn.sigmoid(log_snr)
    sigma_sq = jax.nn.sigmoid(-log_snr)
    return jnp.sqrt(alpha_sq), jnp.sqrt(sigma_sq)
def alpha_sigma_to_log_snr(alpha, sigma):
    """Returns a log SNR, given the scaling factors for the clean image and
    for the noise.

    Mathematically this is log(alpha**2 / sigma**2). It is evaluated as
    2 * (log|alpha| - log|sigma|), so a very small sigma (or alpha) cannot
    overflow or underflow the squared ratio in float32 before the log is
    taken. The absolute values keep the original behaviour of ignoring the
    signs of alpha and sigma, as squaring did.
    """
    return 2 * (jnp.log(jnp.abs(alpha)) - jnp.log(jnp.abs(sigma)))
def t_to_alpha_sigma(t):
    """Maps a timestep to the pair (alpha, sigma).

    alpha = cos(t * pi / 2) scales the clean image and sigma = sin(t * pi / 2)
    scales the noise. At t = 0, alpha = 1 and sigma = 0; at t = 1, alpha is
    approximately 0 and sigma = 1.
    """
    angle = t * jnp.pi / 2
    return jnp.cos(angle), jnp.sin(angle)
def alpha_sigma_to_t(alpha, sigma):
    """Recovers a timestep from the pair (alpha, sigma).

    This is the inverse of ``t_to_alpha_sigma``. arctan2 measures the angle of
    the point (alpha, sigma), and that angle is rescaled by 2 / pi. For
    non-negative alpha and sigma, the angle in [0, pi/2] becomes t in [0, 1].
    """
    angle = jnp.arctan2(sigma, alpha)
    return angle / jnp.pi * 2
def get_ddpm_schedule(ddpm_t):
    """Returns timesteps for the noise schedule from the DDPM paper.

    ``ddpm_t`` is first turned into a log SNR via
    log_snr = -log(expm1(1e-4 + 10 * ddpm_t**2)). That log SNR is converted to
    (alpha, sigma) with ``log_snr_to_alpha_sigma`` and then to a timestep with
    ``alpha_sigma_to_t``.
    """
    exponent = 1e-4 + 10 * ddpm_t**2
    log_snr = -jnp.log(jnp.expm1(exponent))
    return alpha_sigma_to_t(*log_snr_to_alpha_sigma(log_snr))