@@ -1481,3 +1481,180 @@ def c_code(self, *args, **kwargs):
14811481
14821482
14831483betainc_der = BetaIncDer (upgrade_to_float_no_complex , name = "betainc_der" )
1484+
1485+
1486+ class Hyp2F1 (ScalarOp ):
1487+ """
1488+ Gaussian hypergeometric function ``2F1(a, b; c; z)``.
1489+ """
1490+
1491+ nin = 4
1492+ nfunc_spec = ("scipy.special.hyp2f1" , 4 , 1 )
1493+
1494+ @staticmethod
1495+ def st_impl (a , b , c , z ):
1496+ return scipy .special .hyp2f1 (a , b , c , z )
1497+
1498+ def impl (self , a , b , c , z ):
1499+ return Hyp2F1 .st_impl (a , b , c , z )
1500+
1501+ def grad (self , inputs , grads ):
1502+ a , b , c , z = inputs
1503+ (gz ,) = grads
1504+ return [
1505+ gz * hyp2f1_der (a , b , c , z , wrt = 0 ),
1506+ gz * hyp2f1_der (a , b , c , z , wrt = 1 ),
1507+ gz * hyp2f1_der (a , b , c , z , wrt = 2 ),
1508+ gz * ((a * b ) / c ) * hyp2f1 (a + 1 , b + 1 , c + 1 , z ),
1509+ ]
1510+
1511+ def c_code (self , * args , ** kwargs ):
1512+ raise NotImplementedError ()
1513+
1514+
1515+ hyp2f1 = Hyp2F1 (upgrade_to_float , name = "hyp2f1" )
1516+
1517+
1518+ class Hyp2F1Der (ScalarOp ):
1519+ """Derivatives of the Gaussian hypergeometric function :math:`2_F_1(a, b; c; z)`.
1520+
1521+ This is only implemented for one of the first three inputs.
1522+
1523+ Adapted from https://github.com/stan-dev/math/blob/develop/stan/math/prim/fun/grad_2F1.hpp
1524+
1525+ """
1526+
1527+ nin = 5
1528+
1529+ def impl (self , a , b , c , z , wrt ):
1530+ def check_2f1_converges (a , b , c , z ) -> bool :
1531+ num_terms = 0
1532+ is_polynomial = False
1533+
1534+ def is_nonpositive_integer (x ):
1535+ return x <= 0 and x .is_integer ()
1536+
1537+ if is_nonpositive_integer (a ) and abs (a ) >= num_terms :
1538+ is_polynomial = True
1539+ num_terms = int (np .floor (abs (a )))
1540+ if is_nonpositive_integer (b ) and abs (b ) >= num_terms :
1541+ is_polynomial = True
1542+ num_terms = int (np .floor (abs (b )))
1543+
1544+ is_undefined = is_nonpositive_integer (c ) and abs (c ) <= num_terms
1545+
1546+ return not is_undefined and (
1547+ is_polynomial or np .abs (z ) < 1 or (np .abs (z ) == 1 and c > (a + b ))
1548+ )
1549+
1550+ def compute_grad_2f1 (a , b , c , z , wrt ):
1551+ r"""
1552+
1553+ Notes
1554+ -----
1555+ The algorithm can be derived by looking at the ratio of two successive terms in the series:
1556+
1557+ .. math::
1558+
1559+ \beta_{k+1} / \beta_{k} = A(k) / B(k) \\
1560+ \beta_{k+1} = A(k) / B(k) \beta_{k} \\
1561+ d[\beta_{k+1}] = d[A(k) / B(k)] \beta_{k} + A(k) / B(k) d[\beta_{k}]
1562+
1563+ via the product rule.
1564+
1565+ In the :math:`2_F_1`, :math:`A(k) / B(k)` corresponds to
1566+ :math:`(((a + k) (b + k) / ((c + k) (1 + k))) z` The partial
1567+ :math:`d[A(k)/B(k)]` with respect to the first three inputs can be
1568+ obtained from the ratio :math:`A(k)/B(k)`, by dropping the
1569+ respective term
1570+
1571+ .. math::
1572+
1573+ d/da[A(k) / B(k)] = A(k) / B(k) / (a + k) \\
1574+ d/db[A(k) / B(k)] = A(k) / B(k) / (b + k) \\
1575+ d/dc[A(k) / B(k)] = A(k) / B(k) (c + k)
1576+
1577+ The algorithm is implemented in the log scale, which adds the
1578+ complexity of working with absolute terms and tracking their signs.
1579+ """
1580+
1581+ wrt_a = wrt_b = False
1582+ if wrt == 0 :
1583+ wrt_a = True
1584+ elif wrt == 1 :
1585+ wrt_b = True
1586+ elif wrt != 2 :
1587+ raise ValueError (f"wrt must be 0, 1, or 2; got { wrt } " )
1588+
1589+ min_steps = 10 # https://github.com/stan-dev/math/issues/2857
1590+ max_steps = int (1e6 )
1591+ precision = 1e-14
1592+
1593+ res = 0
1594+
1595+ if z == 0 :
1596+ return res
1597+
1598+ log_g_old = - np .inf
1599+ log_t_old = 0.0
1600+ log_t_new = 0.0
1601+ sign_z = np .sign (z )
1602+ log_z = np .log (np .abs (z ))
1603+
1604+ log_g_old_sign = 1
1605+ log_t_old_sign = 1
1606+ log_t_new_sign = 1
1607+ sign_zk = sign_z
1608+
1609+ for k in range (max_steps ):
1610+ p = (a + k ) * (b + k ) / ((c + k ) * (k + 1 ))
1611+ if p == 0 :
1612+ return res
1613+ log_t_new += np .log (np .abs (p )) + log_z
1614+ log_t_new_sign = np .sign (p ) * log_t_new_sign
1615+
1616+ term = log_g_old_sign * log_t_old_sign * np .exp (log_g_old - log_t_old )
1617+ if wrt_a :
1618+ term += np .reciprocal (a + k )
1619+ elif wrt_b :
1620+ term += np .reciprocal (b + k )
1621+ else :
1622+ term -= np .reciprocal (c + k )
1623+
1624+ log_g_old = log_t_new + np .log (np .abs (term ))
1625+ log_g_old_sign = np .sign (term ) * log_t_new_sign
1626+ g_current = log_g_old_sign * np .exp (log_g_old ) * sign_zk
1627+ res += g_current
1628+
1629+ log_t_old = log_t_new
1630+ log_t_old_sign = log_t_new_sign
1631+ sign_zk *= sign_z
1632+
1633+ if k >= min_steps and np .abs (g_current ) <= precision :
1634+ return res
1635+
1636+ warnings .warn (
1637+ f"hyp2f1_der did not converge after { k } iterations" ,
1638+ RuntimeWarning ,
1639+ )
1640+ return np .nan
1641+
1642+ # TODO: We could implement the Euler transform to expand supported domain, as Stan does
1643+ if not check_2f1_converges (a , b , c , z ):
1644+ warnings .warn (
1645+ f"Hyp2F1 does not meet convergence conditions with given arguments a={ a } , b={ b } , c={ c } , z={ z } " ,
1646+ RuntimeWarning ,
1647+ )
1648+ return np .nan
1649+
1650+ return compute_grad_2f1 (a , b , c , z , wrt = wrt )
1651+
1652+ def __call__ (self , a , b , c , z , wrt ):
1653+ # This allows wrt to be a keyword argument
1654+ return super ().__call__ (a , b , c , z , wrt )
1655+
1656+ def c_code (self , * args , ** kwargs ):
1657+ raise NotImplementedError ()
1658+
1659+
1660+ hyp2f1_der = Hyp2F1Der (upgrade_to_float , name = "hyp2f1_der" )
0 commit comments