Commit 1bb42ad

fix: eliminate RuntimeWarnings in von Mises-Fisher loss backward pass (#824)
* fix: eliminate RuntimeWarnings in von Mises-Fisher loss backward pass

  - Replace np.where with boolean masking to avoid double evaluation
  - Add comprehensive unit tests for zero handling in LogCMK.backward
  - Enhanced docstrings with mathematical background and implementation details
  - Added mathematical background documentation in docs/source/models/

  Fixes division by zero warnings while maintaining numerical accuracy.
  Error bound analysis shows <1e-21 accuracy for |κ| < 1e-6.

* Fix mathematical expressions in documentation

  Corrected mathematical expressions in the von Mises-Fisher documentation so they are properly rendered in GitHub markdown.
1 parent 631b1a6 commit 1bb42ad

File tree

4 files changed: +350 −5 lines changed


docs/source/models/models.rst

Lines changed: 7 additions & 0 deletions

```diff
@@ -513,3 +513,10 @@ Below is a minimal example for training a GNN in GraphNeT for energy reconstruct
 Because :code:`ModelConfig` summarises a :code:`Model` completely, including its :code:`Task`\ (s),
 the only modifications required to change the example to reconstruct (or classify) a different attribute than energy, is to pass a :code:`ModelConfig` that defines a model with the corresponding :code:`Task`.
 Similarly, if you wanted to train on a different :code:`Dataset`, you would just have to pass a :code:`DatasetConfig` that defines *that* :code:`Dataset` instead.
+
+Mathematical Background
+-----------------------
+
+For detailed mathematical derivations and implementation details of specific loss functions:
+
+* :doc:`von_mises_fisher_mathematical_background` - Mathematical background for the von Mises-Fisher loss function implementation, including Taylor series analysis and numerical stability considerations.
```
docs/source/models/von_mises_fisher_mathematical_background.md

Lines changed: 197 additions & 0 deletions

@@ -0,0 +1,197 @@

# Mathematical Background: von Mises-Fisher Loss Implementation
## Calculation for 3D

To show that for $m=3$,

$$- \frac{I_{m/2}(\kappa)}{I_{m/2-1}(\kappa)} = \frac{1}{\kappa}-\frac{1}{\tanh(\kappa)}$$

we first substitute $m=3$ into the equation. This gives us:

$$- \frac{I_{3/2}(\kappa)}{I_{1/2}(\kappa)} = \frac{1}{\kappa}-\frac{1}{\tanh(\kappa)}$$

We need to evaluate the left side of this equation, which involves **modified Bessel functions of the first kind** of half-integer order, specifically $I_{3/2}(\kappa)$ and $I_{1/2}(\kappa)$.
### Expressing Bessel Functions

The modified Bessel functions of the first kind of half-integer order can be expressed in terms of elementary hyperbolic functions.

For the order $1/2$, the function is:

$$I_{1/2}(\kappa) = \sqrt{\frac{2}{\pi\kappa}} \sinh(\kappa)$$

For the order $3/2$, the function is:

$$I_{3/2}(\kappa) = \sqrt{\frac{2}{\pi\kappa}} \left( \cosh(\kappa) - \frac{\sinh(\kappa)}{\kappa} \right)$$
### Calculating the Ratio

Now, we can compute the ratio on the left side of the equation:

$$- \frac{I_{3/2}(\kappa)}{I_{1/2}(\kappa)} = - \frac{\sqrt{\frac{2}{\pi\kappa}} \left( \cosh(\kappa) - \frac{\sinh(\kappa)}{\kappa} \right)}{\sqrt{\frac{2}{\pi\kappa}} \sinh(\kappa)}$$

The term $\sqrt{\frac{2}{\pi\kappa}}$ cancels out, leaving:

$$- \frac{\cosh(\kappa) - \frac{\sinh(\kappa)}{\kappa}}{\sinh(\kappa)} = - \left( \frac{\cosh(\kappa)}{\sinh(\kappa)} - \frac{1}{\kappa} \right)$$

Using the definition of the **hyperbolic cotangent**, $\coth(\kappa) = \frac{\cosh(\kappa)}{\sinh(\kappa)}$, we have:

$$- \left(\coth(\kappa) - \frac{1}{\kappa}\right) = \frac{1}{\kappa} - \coth(\kappa)$$

### Final Result

Finally, since $\coth(\kappa) = \frac{1}{\tanh(\kappa)}$, we can write:

$$\frac{1}{\kappa} - \frac{1}{\tanh(\kappa)}$$

This is exactly the right side of the initial equation. Thus, we have shown that for $m=3$:

$$- \frac{I_{3/2}(\kappa)}{I_{1/2}(\kappa)} = \frac{1}{\kappa}-\frac{1}{\tanh(\kappa)} \quad \blacksquare$$

---
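As a quick numerical sanity check (illustrative only, not part of the committed change), the identity can be confirmed with SciPy's `iv` at a few sample points:

```python
import numpy as np
from scipy import special

# Verify -I_{3/2}(k)/I_{1/2}(k) == 1/k - 1/tanh(k) at a few sample points
kappa = np.array([0.25, 1.0, 2.0, 5.0])
lhs = -special.iv(1.5, kappa) / special.iv(0.5, kappa)
rhs = 1.0 / kappa - 1.0 / np.tanh(kappa)
assert np.allclose(lhs, rhs)
```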
## Taylor Series Approximation

To prove that the Taylor series for $f(\kappa) = \frac{1}{\kappa}-\frac{1}{\tanh(\kappa)}$ is well-defined and alternating at $\kappa = 0$, we first need to analyze the behavior of the function at the origin and then examine the structure of its series expansion.

### Well-Defined at $\kappa = 0$

The function $f(\kappa)$ appears to be singular at $\kappa=0$ because both $\frac{1}{\kappa}$ and $\frac{1}{\tanh(\kappa)}$ diverge. To determine if the singularity is removable, we can evaluate the limit as $\kappa \to 0$.

First, rewrite the function as a single fraction:

$$f(\kappa) = \frac{1}{\kappa} - \frac{\cosh(\kappa)}{\sinh(\kappa)} = \frac{\sinh(\kappa) - \kappa\cosh(\kappa)}{\kappa\sinh(\kappa)}$$

We can find the limit by expanding the hyperbolic functions into their Taylor series around $\kappa = 0$:

$$\sinh(\kappa) = \kappa + \frac{\kappa^3}{3!} + \frac{\kappa^5}{5!} + O(\kappa^7)$$

$$\cosh(\kappa) = 1 + \frac{\kappa^2}{2!} + \frac{\kappa^4}{4!} + O(\kappa^6)$$

Substituting these into the numerator and denominator:

**Numerator:**

$$\sinh(\kappa) - \kappa\cosh(\kappa) = \left(\kappa + \frac{\kappa^3}{6} + \dots\right) - \kappa\left(1 + \frac{\kappa^2}{2} + \dots\right)$$

$$= \left(\kappa + \frac{\kappa^3}{6}\right) - \left(\kappa + \frac{\kappa^3}{2}\right) + \dots = \left(\frac{1}{6} - \frac{1}{2}\right)\kappa^3 + \dots = -\frac{1}{3}\kappa^3 + O(\kappa^5)$$

**Denominator:**

$$\kappa\sinh(\kappa) = \kappa\left(\kappa + \frac{\kappa^3}{6} + \dots\right) = \kappa^2 + \frac{\kappa^4}{6} + O(\kappa^6)$$

Now, the limit of the function is:

$$\lim_{\kappa \to 0} f(\kappa) = \lim_{\kappa \to 0} \frac{-\frac{1}{3}\kappa^3 + O(\kappa^5)}{\kappa^2 + O(\kappa^4)} = \lim_{\kappa \to 0} \frac{\kappa^3(-\frac{1}{3} + O(\kappa^2))}{\kappa^2(1 + O(\kappa^2))} = \lim_{\kappa \to 0} \kappa \frac{-\frac{1}{3} + O(\kappa^2)}{1 + O(\kappa^2)} = 0$$

Since the limit exists and is finite, the singularity at $\kappa=0$ is removable. We can define $f(0) = 0$, and thus the Taylor series for $f(\kappa)$ around $\kappa=0$ is **well-defined**.
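This limit behavior is easy to check numerically (a small illustrative sketch, not part of the commit). The single-fraction form avoids the worst of the cancellation between the two individually diverging terms:

```python
import numpy as np

def f(kappa: float) -> float:
    """f(k) = 1/k - coth(k), written as a single fraction."""
    return (np.sinh(kappa) - kappa * np.cosh(kappa)) / (kappa * np.sinh(kappa))

# As kappa -> 0, f(kappa) tracks its leading Taylor term -kappa/3 and tends to 0
for k in [1e-1, 1e-2, 1e-3, 1e-4]:
    assert np.isclose(f(k), -k / 3, rtol=1e-2)
assert abs(f(1e-4)) < 1e-4  # consistent with the limit f(0) = 0
```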
### Alternating Series

To show the series is alternating, we derive its form using the known series expansion for $\kappa \coth(\kappa)$, which involves the **Bernoulli numbers** $B_{2n}$. The expansion is:

$$\kappa \coth(\kappa) = \sum_{n=0}^{\infty} \frac{B_{2n} (2\kappa)^{2n}}{(2n)!} = 1 + \frac{1}{3}\kappa^2 - \frac{1}{45}\kappa^4 + \frac{2}{945}\kappa^6 - \dots$$

Our function can be written as $f(\kappa) = \frac{1}{\kappa} - \coth(\kappa) = \frac{1 - \kappa \coth(\kappa)}{\kappa}$. Substituting the series:

$$f(\kappa) = \frac{1}{\kappa} \left( 1 - \sum_{n=0}^{\infty} \frac{B_{2n} (2\kappa)^{2n}}{(2n)!} \right)$$

The first term of the sum (for $n=0$) is $\frac{B_0 (2\kappa)^0}{0!} = 1$, since $B_0 = 1$.

$$f(\kappa) = \frac{1}{\kappa} \left( 1 - \left(1 + \sum_{n=1}^{\infty} \frac{B_{2n} (2\kappa)^{2n}}{(2n)!} \right) \right) = \frac{1}{\kappa} \left( - \sum_{n=1}^{\infty} \frac{B_{2n} 2^{2n} \kappa^{2n}}{(2n)!} \right)$$

$$f(\kappa) = - \sum_{n=1}^{\infty} \frac{B_{2n} 2^{2n} \kappa^{2n-1}}{(2n)!}$$

The sign of the Bernoulli numbers $B_{2n}$ for $n \ge 1$ alternates according to the formula $\operatorname{sgn}(B_{2n}) = (-1)^{n-1}$, so we can write $B_{2n} = (-1)^{n-1} |B_{2n}|$. Substituting this into the series for $f(\kappa)$:

$$f(\kappa) = - \sum_{n=1}^{\infty} \frac{(-1)^{n-1} |B_{2n}| 2^{2n} \kappa^{2n-1}}{(2n)!} = \sum_{n=1}^{\infty} (-1)^n \frac{|B_{2n}| 2^{2n} \kappa^{2n-1}}{(2n)!}$$

Writing out the first few terms of the series:

$$f(\kappa) = -\frac{|B_2| 2^2}{2!}\kappa + \frac{|B_4| 2^4}{4!}\kappa^3 - \frac{|B_6| 2^6}{6!}\kappa^5 + \dots$$

With $B_2 = 1/6$, $B_4 = -1/30$, $B_6 = 1/42$:

$$f(\kappa) = -\frac{1/6 \cdot 4}{2}\kappa + \frac{1/30 \cdot 16}{24}\kappa^3 - \frac{1/42 \cdot 64}{720}\kappa^5 + \dots = -\frac{1}{3}\kappa + \frac{1}{45}\kappa^3 - \frac{2}{945}\kappa^5 + \dots$$

The series for $f(\kappa)$ contains only odd powers of $\kappa$. The coefficient of the term $\kappa^{2n-1}$ is:

$$c_{2n-1} = (-1)^n \frac{|B_{2n}| 2^{2n}}{(2n)!}$$

The coefficient of the next non-zero term, $\kappa^{2(n+1)-1} = \kappa^{2n+1}$, is:

$$c_{2n+1} = (-1)^{n+1} \frac{|B_{2(n+1)}| 2^{2(n+1)}}{(2(n+1))!}$$

Since $|B_{2k}|$ is positive for all $k \ge 1$, the sign of $c_{2n-1}$ is determined by $(-1)^n$, and the sign of $c_{2n+1}$ by $(-1)^{n+1}$; these are opposite. Therefore, the coefficients of successive non-zero terms have alternating signs. This proves that the Taylor series for $f(\kappa)$ at $\kappa=0$ is an **alternating series**.
### Error Bound

The Taylor series derived for $f(\kappa)$ is alternating: the signs of successive terms alternate. For such series, the truncation error can be estimated using the **Alternating Series Estimation Theorem**: if a convergent alternating series is approximated by its $N$-th partial sum, the absolute value of the error (the remainder) is at most the absolute value of the first neglected term, provided the absolute values of the terms decrease monotonically to zero.

These conditions are met by the series for $f(\kappa)$ within its radius of convergence ($|\kappa| < \pi$). The magnitude of the general term, $\frac{|B_{2n}| 2^{2n} |\kappa|^{2n-1}}{(2n)!}$, tends to zero as $n \to \infty$, ensuring convergence, and for any given $\kappa$ in this interval the term magnitudes are eventually monotonically decreasing; for sufficiently small $\kappa$, the decrease holds from the very first term. The error in approximating the function by the sum of its first $N$ terms is therefore bounded by the magnitude of the $(N+1)$-th term. In particular, the error in the approximation $f(\kappa) \approx -\frac{1}{3}\kappa$ is at most the next term's magnitude, $\frac{1}{45}|\kappa|^3$. Thus:

$$\left\lvert \kappa \right\rvert < 10^{-6} \implies \varepsilon \leq \frac{|\kappa|^3}{45} \lesssim \mathcal{O}(10^{-20})$$

---
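The tightness of the alternating-series bound can be spot-checked numerically (an illustrative sketch, not part of the commit). At $\kappa = 0.1$ the error of the first-order approximation should sit just below $\kappa^3/45$:

```python
import numpy as np

kappa = 0.1
exact = 1.0 / kappa - 1.0 / np.tanh(kappa)  # exact gradient formula for m=3
approx = -kappa / 3                          # first-order Taylor term
err = abs(exact - approx)

# Alternating Series Estimation Theorem: |error| <= first neglected term
bound = kappa**3 / 45
assert err <= bound
assert err >= 0.99 * bound  # the bound is nearly tight at this kappa
```

The same bound evaluated at the $|\kappa| < 10^{-6}$ threshold gives the $\sim 10^{-20}$ figure quoted above.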
## Implementation in GraphNeT

### Problem Statement

The mathematical derivation above provides exact formulas for computing gradients in the von Mises-Fisher loss function. However, a naive implementation would encounter **division by zero errors** when $\kappa = 0$, even though the mathematical limit is well-defined. This creates RuntimeWarnings and potential numerical instability.

### Numerical Challenge

The core issue arises in the backward pass when computing:

$$\frac{\partial}{\partial \kappa} \log C_3(\kappa) = \frac{1}{\kappa} - \frac{1}{\tanh(\kappa)}$$

For $\kappa = 0$:

- Both $\frac{1}{\kappa}$ and $\frac{1}{\tanh(\kappa)}$ diverge individually
- Their difference converges to the finite limit $\lim_{\kappa \to 0} f(\kappa) = 0$
- Standard floating-point evaluation triggers division by zero warnings

### Solution Strategy

Our implementation uses **boolean masking** to conditionally apply different computational approaches:
```python
# Initialize gradient array
grads = np.zeros_like(kappa)

# Handle small kappa values (including zero)
small_mask = np.abs(kappa) < 1e-6
grads[small_mask] = -kappa[small_mask] / 3

# Handle large kappa values
large_mask = ~small_mask
if np.any(large_mask):
    kappa_large = kappa[large_mask]
    grads[large_mask] = 1 / kappa_large - 1 / np.tanh(kappa_large)
```
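The snippet above can be wrapped into a standalone function to see the behavior end to end (a sketch; the name `grad_log_cmk3` is ours for illustration, not a GraphNeT API):

```python
import numpy as np

def grad_log_cmk3(kappa: np.ndarray) -> np.ndarray:
    """Gradient of log C_3(kappa), numerically stable at kappa = 0."""
    grads = np.zeros_like(kappa)
    small_mask = np.abs(kappa) < 1e-6
    grads[small_mask] = -kappa[small_mask] / 3  # Taylor branch
    large_mask = ~small_mask
    if np.any(large_mask):
        kappa_large = kappa[large_mask]
        grads[large_mask] = 1 / kappa_large - 1 / np.tanh(kappa_large)
    return grads

grads = grad_log_cmk3(np.array([0.0, 1e-7, 0.5, 10.0]))
assert np.all(np.isfinite(grads))
assert grads[0] == 0.0  # exact zero at kappa = 0, no warning raised
```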
### Key Implementation Features

1. **Threshold Selection**: $|\kappa| < 10^{-6}$
   - Based on error analysis: truncation error $\leq \frac{|\kappa|^3}{45} \lesssim 10^{-20}$
   - Well below the precision of typical double-precision arithmetic
   - Provides excellent numerical accuracy

2. **Boolean Masking vs. `np.where()`**
   - **Avoided**: `np.where(condition, small_branch, large_branch)`
     - Evaluates both branches, so it still triggers warnings
   - **Used**: Boolean indexing with separate computations
     - Only evaluates the necessary expressions
     - Eliminates all RuntimeWarnings

3. **Mathematical Consistency**
   - Small $\kappa$: uses the first-order Taylor approximation $-\kappa/3$
   - Large $\kappa$: uses the exact formula $\frac{1}{\kappa} - \frac{1}{\tanh(\kappa)}$
   - The seamless transition preserves continuity and differentiability

4. **Edge Case Handling**
   - $\kappa = 0$: returns exactly $0$ (no approximation needed)
   - Multiple zeros in a batch: each handled independently
   - Mixed arrays: efficient vectorized computation
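The difference between the two approaches is easy to demonstrate (an illustrative sketch): `np.where` evaluates both branches eagerly, so the division by zero happens even though its result is discarded, while boolean masking only touches the safe elements:

```python
import warnings
import numpy as np

kappa = np.array([0.0, 0.5, 2.0])

# np.where evaluates BOTH branches before selecting -> divides by zero anyway
with warnings.catch_warnings(record=True) as caught_where:
    warnings.simplefilter("always")
    _ = np.where(np.abs(kappa) < 1e-6, -kappa / 3, 1 / kappa - 1 / np.tanh(kappa))
assert any(issubclass(w.category, RuntimeWarning) for w in caught_where)

# Boolean masking only evaluates the exact formula on nonzero elements
with warnings.catch_warnings(record=True) as caught_masked:
    warnings.simplefilter("always")
    grads = np.zeros_like(kappa)
    small = np.abs(kappa) < 1e-6
    grads[small] = -kappa[small] / 3
    grads[~small] = 1 / kappa[~small] - 1 / np.tanh(kappa[~small])
assert not any(issubclass(w.category, RuntimeWarning) for w in caught_masked)
```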
### Verification

The implementation is validated through comprehensive unit tests that verify:

- ✅ No RuntimeWarnings generated during the backward pass
- ✅ Gradients remain finite for all input values
- ✅ Mathematical accuracy: $f(0) = 0$ exactly
- ✅ Correct Taylor approximation for small $|\kappa|$
- ✅ Proper handling of arrays containing multiple zeros

This approach ensures both mathematical correctness and numerical stability, making the von Mises-Fisher loss function robust for practical deep learning applications where $\kappa$ values may include zeros or near-zero elements.

src/graphnet/training/loss_functions.py

Lines changed: 55 additions & 5 deletions

```diff
@@ -270,15 +270,65 @@ def forward(
     def backward(
         ctx: Any, grad_output: Tensor
     ) -> Tensor:  # pylint: disable=invalid-name,arguments-differ
-        """Backward pass."""
+        """Backward pass for LogCMK computation.
+
+        Mathematical Background:
+        -----------------------
+        For the von Mises-Fisher distribution, the gradient of log C_m(κ)
+        with respect to κ is given by the negative ratio of modified Bessel
+        functions:
+
+            ∂/∂κ log C_m(κ) = -I_{m/2}(κ) / I_{m/2-1}(κ)
+
+        For m=3, this simplifies to the exact formula:
+
+            ∂/∂κ log C_3(κ) = 1/κ - 1/tanh(κ)
+
+        For small κ values, we use the Taylor series approximation:
+
+            f(κ) = -κ/3 + κ³/45 - 2κ⁵/945 + O(κ⁷)
+
+        The first-order approximation -κ/3 provides sufficient accuracy for
+        |κ| < 1e-6, with truncation error bounded by |κ|³/45 ≲ O(10⁻²⁰).
+
+        Implementation Details:
+        ----------------------
+        Uses boolean masking to avoid double evaluation and RuntimeWarnings:
+        - Small κ: |κ| < 1e-6 → gradient = -κ/3 (Taylor approximation)
+        - Large κ: |κ| ≥ 1e-6 → gradient = 1/κ - 1/tanh(κ) (exact formula)
+
+        References:
+        ----------
+        [1] von Mises-Fisher distribution: Wikipedia
+        [2] arXiv:1812.04616, Section 8.2
+        [3] MIT License (c) 2019 Max Ryabinin - Modified for GraphNeT
+
+        Args:
+            ctx: Autograd context containing saved tensors and metadata.
+            grad_output: Gradient with respect to the output tensor.
+
+        Returns:
+            Tuple of gradients: (None for m, gradient w.r.t. κ).
+        """
         kappa = ctx.saved_tensors[0]
         m = ctx.m
         dtype = ctx.dtype
         kappa = kappa.double().cpu().numpy()
-        grads = -(
-            (scipy.special.iv(m / 2.0, kappa))
-            / (scipy.special.iv(m / 2.0 - 1, kappa))
-        )
+        if np.isclose(m, 3, atol=1e-6):
+            # Initialize gradient array
+            grads = np.zeros_like(kappa)
+
+            # Handle small kappa values (including zero) to avoid
+            # division by zero
+            small_mask = np.abs(kappa) < 1e-6
+            grads[small_mask] = -kappa[small_mask] / 3
+
+            # Handle large kappa values
+            large_mask = ~small_mask
+            if np.any(large_mask):
+                kappa_large = kappa[large_mask]
+                grads[large_mask] = 1 / kappa_large - 1 / np.tanh(kappa_large)
+        else:
+            grads = -(
+                (scipy.special.iv(m / 2.0, kappa))
+                / (scipy.special.iv(m / 2.0 - 1, kappa))
+            )
         return (
             None,
             grad_output
```
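As a sanity check on the branch hand-off (illustrative only, not part of the commit), the two formulas agree at the $|\kappa| = 10^{-6}$ threshold to well within ordinary floating-point noise:

```python
import numpy as np

kappa = 1e-6  # the threshold between the two branches
taylor = -kappa / 3
exact = 1 / kappa - 1 / np.tanh(kappa)

# The mathematical gap is ~kappa**3/45 ≈ 2e-20; float64 cancellation noise
# in the exact formula dominates at roughly 1e-10, still negligible here
assert abs(taylor - exact) < 1e-9
```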

tests/training/test_loss_functions.py

Lines changed: 91 additions & 0 deletions

```diff
@@ -3,6 +3,7 @@
 import numpy as np
 import pytest
 import torch
+import warnings
 from torch import Tensor
 from torch.autograd import grad

@@ -182,3 +183,93 @@ def test_von_mises_fisher_approximation_large_kappa(
     assert torch.allclose(
         grads_approx[exact_is_valid], grads_exact[exact_is_valid], rtol=1e-2
     )
+
+
+def test_logcmk_backward_zero_handling(
+    dtype: torch.dtype = torch.float64,
+) -> None:
+    """Test LogCMK backward pass handles arrays with zero values correctly.
+
+    This test ensures that the LogCMK.backward method correctly handles cases
+    where the kappa tensor contains zero values without raising division by
+    zero errors or warnings. The implementation uses boolean masking to
+    conditionally apply different formulas for small (including zero) and
+    large kappa values, avoiding the double evaluation that would cause
+    RuntimeWarnings.
+
+    Args:
+        dtype: PyTorch data type for the test tensors.
+    """
+    # Test parameters
+    m = 3  # Dimension for which we have the exact formula
+
+    # Create kappa tensor with zeros and other values, including edge cases
+    kappa_values = [0.0, 1e-7, 1e-6, 1e-5, 0.1, 1.0, 10.0]
+    kappa = torch.tensor(kappa_values, dtype=dtype, requires_grad=True)
+
+    # Forward pass using VonMisesFisherLoss.log_cmk_exact, which internally
+    # uses LogCMK
+    result = VonMisesFisherLoss.log_cmk_exact(m, kappa)
+
+    # Test that backward pass doesn't raise any errors or warnings.
+    # Capture warnings to ensure no RuntimeWarnings are generated.
+    with warnings.catch_warnings(record=True) as caught_warnings:
+        warnings.simplefilter("always")
+        try:
+            grads = torch.autograd.grad(
+                outputs=result.sum(),
+                inputs=kappa,
+                grad_outputs=None,
+                retain_graph=False,
+                create_graph=False,
+            )[0]
+            backward_success = True
+            error_msg = ""
+        except (ZeroDivisionError, RuntimeWarning) as e:
+            backward_success = False
+            error_msg = str(e)
+
+    # Verify no errors occurred
+    assert backward_success, f"Backward pass failed with error: {error_msg}"
+
+    # Verify no RuntimeWarnings were generated
+    runtime_warnings = [
+        w for w in caught_warnings if issubclass(w.category, RuntimeWarning)
+    ]
+    assert len(runtime_warnings) == 0, (
+        "RuntimeWarnings were generated: "
+        f"{[str(w.message) for w in runtime_warnings]}"
+    )
+
+    # Verify gradients are finite
+    assert torch.all(
+        torch.isfinite(grads)
+    ), "Gradients should be finite for all kappa values"
+
+    # Test specific values for correctness.
+    # For kappa=0, the gradient should be -kappa/3 = 0.
+    zero_idx = 0  # Index where kappa=0
+    assert torch.isclose(
+        grads[zero_idx], torch.tensor(0.0, dtype=dtype)
+    ), f"Gradient at kappa=0 should be 0, got {grads[zero_idx]}"
+
+    # For very small kappa (1e-7), should use the -kappa/3 approximation
+    small_kappa_idx = 1  # Index where kappa=1e-7
+    expected_small_grad = -kappa_values[small_kappa_idx] / 3
+    assert torch.isclose(
+        grads[small_kappa_idx],
+        torch.tensor(expected_small_grad, dtype=dtype),
+        atol=1e-10,
+    ), "Gradient for small kappa should use -kappa/3 approximation"
+
+    # Test with an array containing multiple zeros
+    kappa_multi_zero = torch.tensor(
+        [0.0, 0.0, 1.0, 0.0, 10.0], dtype=dtype, requires_grad=True
+    )
+    result_multi = VonMisesFisherLoss.log_cmk_exact(m, kappa_multi_zero)
+
+    with warnings.catch_warnings(record=True) as caught_warnings_multi:
+        warnings.simplefilter("always")
+        try:
+            grads_multi = torch.autograd.grad(
+                outputs=result_multi.sum(),
+                inputs=kappa_multi_zero,
+                grad_outputs=None,
+            )[0]
+            multi_zero_success = True
+        except (ZeroDivisionError, RuntimeWarning):
+            multi_zero_success = False
+
+    assert multi_zero_success, "Should handle arrays with multiple zero values"
+    assert torch.all(
+        torch.isfinite(grads_multi)
+    ), "All gradients should be finite with multiple zeros"
+
+    # Verify no RuntimeWarnings for the multiple-zeros case
+    runtime_warnings_multi = [
+        w
+        for w in caught_warnings_multi
+        if issubclass(w.category, RuntimeWarning)
+    ]
+    assert len(runtime_warnings_multi) == 0, (
+        "RuntimeWarnings were generated with multiple zeros: "
+        f"{[str(w.message) for w in runtime_warnings_multi]}"
+    )
+
+    # Verify that zero elements have zero gradients
+    zero_mask = kappa_multi_zero == 0.0
+    assert torch.all(
+        grads_multi[zero_mask] == 0.0
+    ), "Zero kappa values should have zero gradients"
```
