Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit d86e2e3

Browse files
ColtAllen authored and brandonwillard committed
Add Hyp2F1 Op and poch, factorial functions
1 parent 00c8b41 commit d86e2e3

6 files changed

Lines changed: 440 additions & 3 deletions

File tree

aesara/scalar/math.py

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1481,3 +1481,180 @@ def c_code(self, *args, **kwargs):
14811481

14821482

14831483
betainc_der = BetaIncDer(upgrade_to_float_no_complex, name="betainc_der")
1484+
1485+
1486+
class Hyp2F1(ScalarOp):
    """Scalar `Op` for the Gaussian hypergeometric function ``2F1(a, b; c; z)``.

    Numerical evaluation is delegated to `scipy.special.hyp2f1`; no C
    implementation is provided.
    """

    nin = 4
    nfunc_spec = ("scipy.special.hyp2f1", 4, 1)

    @staticmethod
    def st_impl(a, b, c, z):
        """Evaluate ``2F1(a, b; c; z)`` with `scipy.special.hyp2f1`."""
        return scipy.special.hyp2f1(a, b, c, z)

    def impl(self, a, b, c, z):
        """Python implementation of the `Op`; forwards to `Hyp2F1.st_impl`."""
        return Hyp2F1.st_impl(a, b, c, z)

    def grad(self, inputs, grads):
        """Build the gradient terms for the four inputs, in input order.

        Parameters
        ----------
        inputs
            The four inputs ``a``, ``b``, ``c`` and ``z``.
        grads
            A one-element sequence holding the gradient with respect to the
            output.

        Returns
        -------
        list
            Four terms: the first three use `hyp2f1_der` with ``wrt`` equal
            to ``0``, ``1`` and ``2``; the last one is the derivative with
            respect to ``z``.
        """
        a, b, c, z = inputs
        (output_grad,) = grads

        # Derivatives with respect to the parameters a, b and c.
        terms = [
            output_grad * hyp2f1_der(a, b, c, z, wrt=index) for index in (0, 1, 2)
        ]
        # d/dz 2F1(a, b; c; z) = (a * b / c) * 2F1(a + 1, b + 1; c + 1; z)
        terms.append(output_grad * ((a * b) / c) * hyp2f1(a + 1, b + 1, c + 1, z))
        return terms

    def c_code(self, *args, **kwargs):
        """No C implementation is available for this `Op`."""
        raise NotImplementedError()
1513+
1514+
1515+
# Module-level instance of the `Hyp2F1` op; `Hyp2F1.grad` calls it to build the
# derivative with respect to `z`.
hyp2f1 = Hyp2F1(upgrade_to_float, name="hyp2f1")
1516+
1517+
1518+
class Hyp2F1Der(ScalarOp):
    """Derivatives of the Gaussian hypergeometric function :math:`{}_2F_1(a, b; c; z)`.

    This is only implemented for one of the first three inputs. The fifth
    input, ``wrt``, selects which one: ``0`` for ``a``, ``1`` for ``b`` and
    ``2`` for ``c``.

    Adapted from https://github.com/stan-dev/math/blob/develop/stan/math/prim/fun/grad_2F1.hpp

    """

    nin = 5

    def impl(self, a, b, c, z, wrt):
        """Numerically evaluate the derivative of ``2F1(a, b; c; z)``.

        Parameters
        ----------
        a, b, c, z
            Scalar arguments of the hypergeometric function.
        wrt
            ``0``, ``1`` or ``2``: differentiate with respect to ``a``, ``b``
            or ``c`` respectively.

        Returns
        -------
        The derivative value. ``nan`` is returned, after a `RuntimeWarning`,
        when the arguments fail the convergence check or when the series has
        not converged within the iteration limit.

        Raises
        ------
        ValueError
            If the arguments pass the convergence check and ``wrt`` is not
            ``0``, ``1`` or ``2``.
        """

        def check_2f1_converges(a, b, c, z) -> bool:
            """Return whether the power series of ``2F1(a, b; c; z)`` can be summed."""
            num_terms = 0
            is_polynomial = False

            def is_nonpositive_integer(x):
                # Coerce to ``float`` first: a plain Python ``int`` only has
                # ``is_integer`` from Python 3.12 onwards.
                return x <= 0 and float(x).is_integer()

            # A non-positive integer `a` or `b` truncates the series to a
            # polynomial with `abs(a)` (resp. `abs(b)`) terms.
            if is_nonpositive_integer(a) and abs(a) >= num_terms:
                is_polynomial = True
                num_terms = int(np.floor(abs(a)))
            if is_nonpositive_integer(b) and abs(b) >= num_terms:
                is_polynomial = True
                num_terms = int(np.floor(abs(b)))

            # NOTE(review): when neither `a` nor `b` is a non-positive integer,
            # `num_terms` stays 0, so a negative-integer `c` is not flagged
            # here -- confirm this matches the intended domain.
            is_undefined = is_nonpositive_integer(c) and abs(c) <= num_terms

            return not is_undefined and (
                is_polynomial or np.abs(z) < 1 or (np.abs(z) == 1 and c > (a + b))
            )

        def compute_grad_2f1(a, b, c, z, wrt):
            r"""Sum the term-wise derivative of the power series.

            Notes
            -----
            The algorithm can be derived by looking at the ratio of two successive terms in the series:

            .. math::

                \beta_{k+1} / \beta_{k} = A(k) / B(k) \\
                \beta_{k+1} = A(k) / B(k) \beta_{k} \\
                d[\beta_{k+1}] = d[A(k) / B(k)] \beta_{k} + A(k) / B(k) d[\beta_{k}]

            via the product rule.

            In the :math:`{}_2F_1`, :math:`A(k) / B(k)` corresponds to
            :math:`((a + k) (b + k)) / ((c + k) (1 + k)) z`. The partial
            :math:`d[A(k)/B(k)]` with respect to the first three inputs can be
            obtained from the ratio :math:`A(k)/B(k)`, by dropping the
            respective term

            .. math::

                d/da[A(k) / B(k)] = A(k) / B(k) / (a + k) \\
                d/db[A(k) / B(k)] = A(k) / B(k) / (b + k) \\
                d/dc[A(k) / B(k)] = -A(k) / B(k) / (c + k)

            The algorithm is implemented in the log scale, which adds the
            complexity of working with absolute terms and tracking their signs.
            """

            wrt_a = wrt_b = False
            if wrt == 0:
                wrt_a = True
            elif wrt == 1:
                wrt_b = True
            elif wrt != 2:
                raise ValueError(f"wrt must be 0, 1, or 2; got {wrt}")

            min_steps = 10  # https://github.com/stan-dev/math/issues/2857
            max_steps = int(1e6)
            precision = 1e-14

            res = 0

            # Every term of the derivative series carries a positive power of
            # z, so the sum is zero at z == 0.
            if z == 0:
                return res

            # Log-magnitudes and signs of the previous derivative term (`g`)
            # and of the previous/current series terms (`t`).
            log_g_old = -np.inf
            log_t_old = 0.0
            log_t_new = 0.0
            sign_z = np.sign(z)
            log_z = np.log(np.abs(z))

            log_g_old_sign = 1
            log_t_old_sign = 1
            log_t_new_sign = 1
            sign_zk = sign_z

            for k in range(max_steps):
                # Ratio of successive series terms, without the factor z.
                p = (a + k) * (b + k) / ((c + k) * (k + 1))
                if p == 0:
                    # The series terminates here (polynomial case).
                    return res
                log_t_new += np.log(np.abs(p)) + log_z
                log_t_new_sign = np.sign(p) * log_t_new_sign

                term = log_g_old_sign * log_t_old_sign * np.exp(log_g_old - log_t_old)
                # `np.reciprocal` uses integer arithmetic for integer input and
                # would return 0 there, so the argument is coerced to float.
                if wrt_a:
                    term += np.reciprocal(float(a + k))
                elif wrt_b:
                    term += np.reciprocal(float(b + k))
                else:
                    term -= np.reciprocal(float(c + k))

                log_g_old = log_t_new + np.log(np.abs(term))
                log_g_old_sign = np.sign(term) * log_t_new_sign
                g_current = log_g_old_sign * np.exp(log_g_old) * sign_zk
                res += g_current

                log_t_old = log_t_new
                log_t_old_sign = log_t_new_sign
                sign_zk *= sign_z

                if k >= min_steps and np.abs(g_current) <= precision:
                    return res

            warnings.warn(
                f"hyp2f1_der did not converge after {k} iterations",
                RuntimeWarning,
            )
            return np.nan

        # TODO: We could implement the Euler transform to expand supported domain, as Stan does
        if not check_2f1_converges(a, b, c, z):
            warnings.warn(
                f"Hyp2F1 does not meet convergence conditions with given arguments a={a}, b={b}, c={c}, z={z}",
                RuntimeWarning,
            )
            return np.nan

        return compute_grad_2f1(a, b, c, z, wrt=wrt)

    def __call__(self, a, b, c, z, wrt):
        """Apply the `Op`, forwarding ``wrt`` positionally to the parent class."""
        # This allows wrt to be a keyword argument
        return super().__call__(a, b, c, z, wrt)

    def c_code(self, *args, **kwargs):
        """No C implementation is available for this `Op`."""
        raise NotImplementedError()
1658+
1659+
1660+
# Module-level instance of the `Hyp2F1Der` op; `Hyp2F1.grad` calls it with
# `wrt` set to 0, 1 and 2.
hyp2f1_der = Hyp2F1Der(upgrade_to_float, name="hyp2f1_der")

aesara/tensor/inplace.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,11 @@ def conj_inplace(a):
392392
"""elementwise conjugate (inplace on `a`)"""
393393

394394

395+
@scalar_elemwise
def hyp2f1_inplace(a, b, c, z):
    """Gaussian hypergeometric function ``2F1(a, b; c; z)`` (inplace on `a`)"""
398+
399+
395400
pprint.assign(add_inplace, printing.OperatorPrinter("+=", -2, "either"))
396401
pprint.assign(mul_inplace, printing.OperatorPrinter("*=", -1, "either"))
397402
pprint.assign(sub_inplace, printing.OperatorPrinter("-=", -2, "left"))

aesara/tensor/math.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1386,6 +1386,16 @@ def gammal(k, x):
13861386
"""Lower incomplete gamma function."""
13871387

13881388

1389+
@scalar_elemwise
def hyp2f1(a, b, c, z):
    """Gaussian hypergeometric function ``2F1(a, b; c; z)``."""
1392+
1393+
1394+
# NOTE(review): the scalar `Hyp2F1Der` op declares `nin = 5` and is called with
# a trailing `wrt` input, which this signature does not list -- confirm whether
# `wrt` should appear here.
@scalar_elemwise
def hyp2f1_der(a, b, c, z):
    """Derivatives of the Gaussian hypergeometric function ``2F1(a, b; c; z)``."""
1397+
1398+
13891399
@scalar_elemwise
13901400
def j0(x):
13911401
"""Bessel function of the first kind of order 0."""
@@ -3128,6 +3138,8 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
31283138
"power",
31293139
"logaddexp",
31303140
"logsumexp",
3141+
"hyp2f1",
3142+
"hyp2f1_der",
31313143
]
31323144

31333145
DEPRECATED_NAMES = [

aesara/tensor/special.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from aesara.graph.basic import Apply
88
from aesara.link.c.op import COp
99
from aesara.tensor.basic import as_tensor_variable
10-
from aesara.tensor.math import neg, sum
10+
from aesara.tensor.math import gamma, neg, sum
1111

1212

1313
class SoftmaxGrad(COp):
@@ -768,7 +768,19 @@ def log_softmax(c, axis=UNSET_AXIS):
768768
return LogSoftmax(axis=axis)(c)
769769

770770

771+
def poch(z, m):
    """Compute the Pochhammer symbol (rising factorial) of `z` and `m`.

    The result is built as the ratio ``gamma(z + m) / gamma(z)``.
    """
    numerator = gamma(z + m)
    denominator = gamma(z)
    return numerator / denominator
774+
775+
776+
def factorial(n):
    """Compute the factorial of `n`, expressed as ``gamma(n + 1)``."""
    shifted = n + 1
    return gamma(shifted)
779+
780+
771781
__all__ = [
772782
"softmax",
773783
"log_softmax",
784+
"poch",
785+
"factorial",
774786
]

0 commit comments

Comments
 (0)