Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import numpy as np

import jax
from jax import lax
from jax.experimental.sparse import BCOO
import jax.nn as nn
Expand Down Expand Up @@ -1909,18 +1910,40 @@ def variance(self):
return jnp.broadcast_to(var, self.batch_shape)

def cdf(self, value):
from tensorflow_probability.substrates.jax.math import betainc as tfp_betainc

# Ref: https://en.wikipedia.org/wiki/Student's_t-distribution#Related_distributions
# X^2 ~ F(1, df) -> df / (df + X^2) ~ Beta(df/2, 0.5)
scaled = (value - self.loc) / self.scale
scaled_squared = scaled * scaled
beta_value = self.df / (self.df + scaled_squared)

float_type = (
jnp.promote_types(self.df.dtype, beta_value.dtype)
if jax.config.read("jax_enable_x64")
else jnp.dtype("float32")
)

# when scaled < 0, returns 0.5 * Beta(df/2, 0.5).cdf(beta_value)
# when scaled > 0, returns 1 - 0.5 * Beta(df/2, 0.5).cdf(beta_value)
scaled_sign_half = 0.5 * jnp.sign(scaled)
return (
0.5
+ scaled_sign_half
- 0.5 * jnp.sign(scaled) * betainc(0.5 * self.df, 0.5, beta_value)
- 0.5
* jnp.sign(scaled)
* tfp_betainc(
0.5
* (
self.df.astype(float_type)
if isinstance(self.df, (jnp.ndarray, np.ndarray))
else self.df
),
0.5,
beta_value.astype(float_type)
if isinstance(self.df, (jnp.ndarray, np.ndarray))
else beta_value,
)
)

def icdf(self, q):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
"jaxns==1.0.0",
"optax>=0.0.6",
"pyyaml", # flax dependency
"tensorflow_probability>=0.15.0",
"tensorflow_probability>=0.17.0",
],
"examples": [
"arviz",
Expand Down