Implement logccdf helper for numerically stable log survival function #7996
ricardoV94 merged 24 commits into pymc-devs:main
Test that pm.Censored computes log-probabilities stably at the bounds:
- Right censoring (upper bound): log(1 - CDF) when CDF ≈ 1
- Left censoring (lower bound): log(CDF) when CDF ≈ 0

Uses pm.Censored with Normal(0, 1) at ±40 standard deviations.
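The two bound cases described in this commit can be sketched without PyMC. The block below is a minimal standard-library illustration, not the code under test: `normal_logsf`, `normal_logcdf`, and `censored_normal_logp` are hypothetical names, and the far-tail branch uses a truncated asymptotic series as a stand-in for a properly stable survival function.

```python
import math

LOG_SQRT_2PI = 0.5 * math.log(2.0 * math.pi)


def normal_logpdf(x: float) -> float:
    """Log density of the standard Normal."""
    return -0.5 * x * x - LOG_SQRT_2PI


def normal_logsf(x: float) -> float:
    """log(1 - CDF(x)) for the standard Normal (illustrative stand-in).

    Uses erfc for moderate x and a truncated asymptotic series in the far
    right tail, where erfc underflows to zero in float64.
    """
    if x < 20.0:
        return math.log(0.5 * math.erfc(x / math.sqrt(2.0)))
    inv2 = 1.0 / (x * x)
    series = 1.0 - inv2 + 3.0 * inv2**2 - 15.0 * inv2**3
    return -0.5 * x * x - math.log(x) - LOG_SQRT_2PI + math.log(series)


def normal_logcdf(x: float) -> float:
    """log(CDF(x)) via the symmetry CDF(x) = 1 - CDF(-x)."""
    return normal_logsf(-x)


def censored_normal_logp(value: float, lower: float, upper: float) -> float:
    """Log-probability of a standard Normal censored to [lower, upper]."""
    if value < lower or value > upper:
        return -math.inf
    if value == lower:
        return normal_logcdf(lower)  # point mass P(X <= lower)
    if value == upper:
        return normal_logsf(upper)  # point mass P(X >= upper)
    return normal_logpdf(value)


print(censored_normal_logp(40.0, -40.0, 40.0))   # right-censoring bound
print(censored_normal_logp(-40.0, -40.0, 40.0))  # left-censoring bound
```

By symmetry both bounds give the same finite value at ±40 standard deviations, which is the property the test checks through `pm.Censored`.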
Add _logccdf (log complementary CDF / log survival function) support:
- pymc/logprob/abstract.py: Add _logccdf singledispatch and _logccdf_helper
- pymc/distributions/distribution.py: Register logccdf methods via metaclass
- pymc/distributions/continuous.py: Add logccdf to Normal using stable normal_lccdf
- pymc/logprob/censoring.py: Use _logccdf for right-censored distributions
- pymc/logprob/binary.py: Use _logccdf for comparison operations
- pymc/logprob/transforms.py: Use _logccdf_helper for monotonic transforms
- pymc/logprob/basic.py: Add public logccdf() function
- pymc/logprob/__init__.py: Export logccdf

This fixes numerical instability when computing log-probabilities for censored Normal distributions at extreme tail values (e.g., 10+ sigma).
Pull request overview
This PR adds a _logccdf dispatcher to provide numerically stable log survival function (log complementary CDF) computations for censored distributions and other operations that require computing log(1 - CDF). The implementation fixes numerical instability issues when evaluating at extreme tail values, such as computing the log-probability of a censored Normal(0,1) distribution at 40 standard deviations, which previously returned -inf instead of the correct value.
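The failure mode is easy to reproduce with the standard library alone. The sketch below is illustrative rather than PyMC code: `log1mexp`, `naive_logccdf`, and `stable_logccdf` are local helper names, and it uses 10 standard deviations because `math.erfc` itself underflows at 40.

```python
import math


def log1mexp(x: float) -> float:
    """log(1 - exp(x)) for x <= 0, using the usual two-branch formulation."""
    if x == 0.0:
        return -math.inf  # 1 - exp(0) is exactly zero
    if x > -math.log(2.0):
        return math.log(-math.expm1(x))
    return math.log1p(-math.exp(x))


def normal_logcdf(x: float) -> float:
    """log(CDF(x)) of the standard Normal."""
    return math.log(0.5 * math.erfc(-x / math.sqrt(2.0)))


def naive_logccdf(x: float) -> float:
    """log(1 - CDF(x)) recovered from logcdf; loses everything once CDF rounds to 1."""
    return log1mexp(normal_logcdf(x))


def stable_logccdf(x: float) -> float:
    """log(1 - CDF(x)) computed directly from the upper tail."""
    return math.log(0.5 * math.erfc(x / math.sqrt(2.0)))


print(naive_logccdf(10.0))   # -inf: the CDF is exactly 1.0 in float64
print(stable_logccdf(10.0))  # finite, roughly -53.2
```

Once the CDF rounds to 1.0, its logarithm is exactly 0 and no post-processing can recover the tail mass, which is why a dedicated survival-function path is needed.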
Key Changes
- Added `_logccdf` singledispatch function and `_logccdf_helper` for numerically stable log survival function computation
- Implemented `logccdf` method for the Normal distribution using the existing `normal_lccdf` function
- Updated censoring, binary comparison, and transform modules to use `_logccdf` when available, with graceful fallback to `pt.log1mexp(logcdf)`
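For intuition about why an `erfcx`-based survival function stays finite, the identity below is the kind of relation such an implementation can rely on. It is a sketch for the standard Normal with CDF $\Phi$, not a transcription of `normal_lccdf`; the large-$x$ line uses only the leading asymptotic term of `erfcx`.

```latex
\begin{aligned}
\log\bigl(1 - \Phi(x)\bigr)
  &= \log\Bigl(\tfrac{1}{2}\,\operatorname{erfc}(z)\Bigr),
  \qquad z = \frac{x}{\sqrt{2}}, \\
  &= \log\Bigl(\tfrac{1}{2}\,\operatorname{erfcx}(z)\Bigr) - z^{2},
  \qquad \operatorname{erfcx}(z) = e^{z^{2}}\operatorname{erfc}(z), \\
  &\approx -\frac{x^{2}}{2} - \log\bigl(x\sqrt{2\pi}\bigr)
  \quad \text{for large } x,
  \text{ since } \operatorname{erfcx}(z) \approx \frac{1}{z\sqrt{\pi}}.
\end{aligned}
```

At $x = 40$ the last line gives $-800 - \log(40\sqrt{2\pi}) \approx -804.61$, consistent with the value of about -804.6 quoted in the PR description.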
Reviewed changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| `pymc/logprob/abstract.py` | Added `_logccdf` dispatcher and `_logccdf_helper` following the same pattern as `_logcdf` |
| `pymc/distributions/distribution.py` | Added metaclass registration for `logccdf` methods to automatically dispatch to distribution-specific implementations |
| `pymc/distributions/continuous.py` | Implemented `logccdf` for Normal distribution using stable `normal_lccdf` function |
| `pymc/logprob/censoring.py` | Updated right-censored distribution logic to use `_logccdf` with fallback for numerical stability |
| `pymc/logprob/binary.py` | Updated comparison operations to use `_logccdf` for improved numerical stability |
| `pymc/logprob/transforms.py` | Updated monotonic transforms to use `_logccdf_helper` for continuous distributions |
| `pymc/logprob/basic.py` | Added public `logccdf()` function with comprehensive documentation and examples |
| `pymc/logprob/__init__.py` | Exported `logccdf` as part of the public API |
| `tests/logprob/test_censoring.py` | Added parametrized tests for numerical stability of censored distributions at extreme tail values |
| `tests/logprob/test_abstract.py` | Added comprehensive tests for `_logccdf_helper`, public `logccdf` function, and numerical stability |
tests/logprob/test_censoring.py (outdated)

    This test uses pm.Censored which is the high-level API for censored distributions.
    """
    import pymc as pm
ricardoV94 left a comment:
Yeah this would always be needed eventually. I left a request. Also can we use this in the Truncated, or it doesn't show up there?
Codecov Report
✅ All modified and coverable lines are covered by tests.

    @@            Coverage Diff             @@
    ##             main    #7996      +/-   ##
    ==========================================
    + Coverage   90.22%   91.47%   +1.25%
    ==========================================
      Files         116      116
      Lines       18972    19065      +93
    ==========================================
    + Hits        17117    17440     +323
    + Misses       1855     1625     -230
Centralizes the fallback logic so callers don't need to handle it. The helper now tries the stable _logccdf first and automatically falls back to log1mexp(logcdf) if not implemented.
Uses stable logccdf for computing log(1 - CDF(lower)) in truncated_logprob and truncated_logcdf instead of the potentially unstable log1mexp(logcdf).
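To see where $\log(1 - \mathrm{CDF}(\text{lower}))$ enters, consider a continuous variable with density $p$ and CDF $F$ truncated below at $a$ with no upper bound. This is a standard derivation given as context, not a transcription of `truncated_logprob` or `truncated_logcdf`:

```latex
\begin{aligned}
\log p_{\text{trunc}}(x) &= \log p(x) - \log\bigl(1 - F(a)\bigr), && x \ge a, \\
\log F_{\text{trunc}}(x) &= \log\bigl(F(x) - F(a)\bigr) - \log\bigl(1 - F(a)\bigr), && x \ge a.
\end{aligned}
```

The normalizing term $\log(1 - F(a))$ is exactly the log survival function at the lower bound, so when $a$ lies far in the right tail a stable `logccdf` keeps it finite where `log1mexp(logcdf)` would return $-\infty$.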
The construct_ir_fgraph returns a single FunctionGraph, not a tuple. Extract ir_valued_rv from outputs and unpack ir_rv and ir_value from inputs.
Move 'import pymc as pm' to top of file with other imports instead of inside the test function.
Yes! → 15806c0
Working on test coverage...
I'll only be able to review later, but the general picture seems great. How urgent is this to get merged?
Not super urgent since we have a workaround: brendanjmeade/celeri#343
…tributions" This reverts commit 3edaae2. This will be safe once celeri requires a release of PyMC where pymc-devs/pymc#7996 has been merged.
The test_logccdf_numerical_stability test already covers this functionality at the public API level. The helper test was redundant since it tested the same numerical stability property through _logccdf_helper.
Verifies that distributions without a registered _logccdf method (e.g., Uniform) use the log1mexp(logcdf) fallback, while distributions with _logccdf (e.g., Normal) use their specialized implementation. The test inspects the computation graph structure rather than just numerical results to ensure the correct code path is exercised.
Increase test bounds from 10/40 to 100 sigma to future-proof against any potential improvements in naive computation methods. At 100 sigma, CDF(100) is truly indistinguishable from 1.0 in float64. Also enhances test docstrings with What/Why/How documentation.
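The float64 claim can be checked directly. This is a small standard-library sketch, independent of PyMC; the leading-order estimate of the true log survival value is an approximation added for scale, not a figure from the PR.

```python
import math
import sys

x = 100.0
z = x / math.sqrt(2.0)

# The CDF rounds to exactly 1.0, so log(1 - CDF) is -inf on the naive path.
cdf = 0.5 * (1.0 + math.erf(z))

# Even the direct upper-tail evaluation underflows to zero in float64.
sf_direct = 0.5 * math.erfc(z)

# Leading-order estimate of the true value: -x^2/2 - log(x * sqrt(2*pi)).
log_sf_approx = -0.5 * x * x - math.log(x * math.sqrt(2.0 * math.pi))

# Log of the smallest positive normal float64, for comparison.
log_smallest_normal = math.log(sys.float_info.min)

print(cdf == 1.0, sf_direct == 0.0)
print(log_sf_approx, log_smallest_normal)
```

Because the true tail probability is far below anything float64 can represent, only a computation carried out in the log domain can return a finite answer at this bound, which is what makes it a robust test point.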
Add detailed docstrings to logccdf tests explaining: - What: what the test verifies - Why: motivation and edge cases being tested - How: the testing methodology
Tests that pm.logccdf works when the random variable depends on transformed parameters, triggering the construct_ir_fgraph fallback path in the public logccdf function.
Tests that logccdf registration works correctly for custom Distribution subclasses using SymbolicRandomVariable with extended_signature, which exercises the params_idxs code path in DistributionMeta.
b770a67 to 36b8672
maresb left a comment:
Ya, sorry about that. In this case I explicitly requested it to be super verbose in the test docstrings because I found it really hard to understand precisely what the tests were intended for and how they work.
Is it better with my latest changes? Anything else you'd like? Would you like me to rebase and squash the redundant commits?
Yeah much better. I left some more comments, about placement of tests and some that seem redundant. Mind doing one more pass?
Thanks! I've tried to address all comments.
Failing test seems unrelated
Title changed: "_logccdf dispatcher for numerically stable log survival function in censored distributions" → "logccdf helper for numerically stable log survival function"
ricardoV94 left a comment:
Thanks @maresb this is a great improvement
Thanks @ricardoV94 for the very careful review! How would you recommend I merge? Squash-merge?
Yeah I squash merged
This contains an essential numerical stability fix: pymc-devs/pymc#7996
…on (pymc-devs#7996)

* Add numerical stability test for censored distributions

  Test that pm.Censored computes log-probabilities stably at the bounds:
  - Right censoring (upper bound): log(1 - CDF) when CDF ≈ 1
  - Left censoring (lower bound): log(CDF) when CDF ≈ 0

  Uses pm.Censored with Normal(0, 1) at ±40 standard deviations.

* Add _logccdf dispatcher for numerically stable log survival function

  Add _logccdf (log complementary CDF / log survival function) support:
  - pymc/logprob/abstract.py: Add _logccdf singledispatch and _logccdf_helper
  - pymc/distributions/distribution.py: Register logccdf methods via metaclass
  - pymc/distributions/continuous.py: Add logccdf to Normal using stable normal_lccdf
  - pymc/logprob/censoring.py: Use _logccdf for right-censored distributions
  - pymc/logprob/binary.py: Use _logccdf for comparison operations
  - pymc/logprob/transforms.py: Use _logccdf_helper for monotonic transforms
  - pymc/logprob/basic.py: Add public logccdf() function
  - pymc/logprob/__init__.py: Export logccdf

  This fixes numerical instability when computing log-probabilities for censored Normal distributions at extreme tail values (e.g., 10+ sigma).

* Move try/except fallback into _logccdf_helper

  Centralizes the fallback logic so callers don't need to handle it. The helper now tries the stable _logccdf first and automatically falls back to log1mexp(logcdf) if not implemented.

* Add _logccdf support to Truncated distribution

  Uses stable logccdf for computing log(1 - CDF(lower)) in truncated_logprob and truncated_logcdf instead of the potentially unstable log1mexp(logcdf).

* Fix logccdf IR rewriting to match logcdf pattern

  The construct_ir_fgraph returns a single FunctionGraph, not a tuple. Extract ir_valued_rv from outputs and unpack ir_rv and ir_value from inputs.

* Fix test import style in test_censoring.py

  Move 'import pymc as pm' to top of file with other imports instead of inside the test function.

* Remove redundant test_logccdf_helper_numerical_stability

  The test_logccdf_numerical_stability test already covers this functionality at the public API level. The helper test was redundant since it tested the same numerical stability property through _logccdf_helper.

* Add test for _logccdf_helper fallback to log1mexp

  Verifies that distributions without a registered _logccdf method (e.g., Uniform) use the log1mexp(logcdf) fallback, while distributions with _logccdf (e.g., Normal) use their specialized implementation. The test inspects the computation graph structure rather than just numerical results to ensure the correct code path is exercised.

* Use ±100 sigma in numerical stability tests

  Increase test bounds from 10/40 to 100 sigma to future-proof against any potential improvements in naive computation methods. At 100 sigma, CDF(100) is truly indistinguishable from 1.0 in float64. Also enhances test docstrings with What/Why/How documentation.

* Enhance test docstrings with What/Why/How documentation

  Add detailed docstrings to logccdf tests explaining:
  - What: what the test verifies
  - Why: motivation and edge cases being tested
  - How: the testing methodology

* Add test for logccdf IR graph rewriting path

  Tests that pm.logccdf works when the random variable depends on transformed parameters, triggering the construct_ir_fgraph fallback path in the public logccdf function.

* Add test for logccdf with SymbolicRandomVariable extended_signature

  Tests that logccdf registration works correctly for custom Distribution subclasses using SymbolicRandomVariable with extended_signature, which exercises the params_idxs code path in DistributionMeta.

* Import log1mexp directly
* Add tests for _logccdf on discrete distributions
* Use _logccdf_helper also for discrete distributions
* Remove verbose inline comments about numerical stability
* Simplify verbose test docstrings
* Simplify graph_contains_log1mexp using pytensor.graph.traversal.ancestors
* Remove test_logccdf_transformed_argument (redundant pm.Model usage)
* Remove _helper tests, keep only user-facing API tests
* Move censored numerical stability test to tests/distributions/test_censored.py
* Add logcdf tests for Erfc/Erfcx transforms
* Explain the test assumption that Normal has a custom ccdf but Uniform doesn't
* Move discrete transform logccdf tests to test_transforms.py
Description
Disclosure: With assistance from `claude-4.5-opus-high` via Cursor.

This PR adds a `_logccdf` (log complementary CDF / log survival function) dispatcher to fix numerical instability when computing log-probabilities for censored distributions at extreme tail values.

The Problem:
For right-censored distributions, the log-probability at the upper bound requires computing `log(1 - CDF)`. The existing implementation uses `log1mexp(logcdf)`, which breaks down when `CDF ≈ 1` (far right tail). For example, a censored `Normal(0, 1)` at 40 standard deviations returns `-inf` instead of the correct `≈ -804.6`.

The Solution:
Add a `_logccdf` dispatcher that allows distributions to provide a numerically stable log survival function. For `Normal`, this uses the existing `normal_lccdf` function (based on `erfcx`) which is stable across the entire domain.

Changes:
- `pymc/logprob/abstract.py`: Add `_logccdf` singledispatch and `_logccdf_helper`
- `pymc/distributions/distribution.py`: Register `logccdf` methods via metaclass
- `pymc/distributions/continuous.py`: Add `logccdf` to `Normal` using stable `normal_lccdf`
- `pymc/logprob/censoring.py`: Use `_logccdf` for right-censored distributions when available
- `pymc/logprob/binary.py`: Use `_logccdf` for comparison operations
- `pymc/logprob/transforms.py`: Use `_logccdf_helper` for monotonic transforms
- `pymc/logprob/basic.py`: Add public `logccdf()` function
- `pymc/logprob/__init__.py`: Export `logccdf`

Related Issue
Checklist
Type of change