Skip to content

Commit c63a445

Browse files
committed
Make # abs tol compare over the complex numbers
For calculations over complex numbers that generate numeric noise, one tends to create small but non-zero imaginary parts. This PR updates the "# abs tol" tolerance setting to work over the complex numbers, as the "abs" suggests complex numbers. The real and imaginary parts are compared separately. The ordinary "# tol" and "# rel tol" are left as is. Fixes #36631
1 parent 79c047c commit c63a445

File tree

5 files changed

+510
-207
lines changed

5 files changed

+510
-207
lines changed
Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
"""
2+
Check tolerance when parsing docstrings
3+
"""
4+
5+
# ****************************************************************************
6+
# Copyright (C) 2012-2018 David Roe <[email protected]>
7+
# 2012 Robert Bradshaw <[email protected]>
8+
# 2012 William Stein <[email protected]>
9+
# 2013 R. Andrew Ohana
10+
# 2013 Volker Braun
11+
# 2013-2018 Jeroen Demeyer <[email protected]>
12+
# 2016-2021 Frédéric Chapoton
13+
# 2017-2018 Erik M. Bray
14+
# 2020 Marc Mezzarobba
15+
# 2020-2023 Matthias Koeppe
16+
# 2022 John H. Palmieri
17+
# 2022 Sébastien Labbé
18+
# 2023 Kwankyu Lee
19+
#
20+
# Distributed under the terms of the GNU General Public License (GPL)
21+
# as published by the Free Software Foundation; either version 2 of
22+
# the License, or (at your option) any later version.
23+
# https://www.gnu.org/licenses/
24+
# ****************************************************************************
25+
26+
import re
27+
from sage.doctest.rif_tol import RIFtol, add_tolerance
28+
from sage.doctest.marked_output import MarkedOutput
29+
30+
31+
# Regex pattern for float without the (optional) leading sign
32+
float_without_sign = r'((\d*\.?\d+)|(\d+\.?))([eE][+-]?\d+)?'
33+
34+
35+
# Regular expression for floats
36+
float_regex = re.compile(r'\s*([+-]?\s*' + float_without_sign + r')')
37+
38+
39+
class ToleranceExceededError(BaseException):
40+
pass
41+
42+
43+
def check_tolerance_real_domain(want: MarkedOutput, got: str) -> tuple[str, str]:
44+
"""
45+
Compare want and got over real domain with tolerance
46+
47+
INPUT:
48+
49+
- ``want`` -- a string, what you want
50+
- ``got`` -- a string, what you got
51+
52+
OUTPUT:
53+
54+
The strings to compare, but with matching float numbers replaced by asterisk.
55+
56+
EXAMPLES::
57+
58+
sage: from sage.doctest.check_tolerance import check_tolerance_real_domain
59+
sage: from sage.doctest.marked_output import MarkedOutput
60+
sage: check_tolerance_real_domain(
61+
....: MarkedOutput('foo:0.2').update(abs_tol=0.3),
62+
....: 'bar:0.4')
63+
['foo:*', 'bar:*']
64+
sage: check_tolerance_real_domain(
65+
....: MarkedOutput('foo:0.2').update(abs_tol=0.3),
66+
....: 'bar:0.6')
67+
Traceback (most recent call last):
68+
...
69+
sage.doctest.check_tolerance.ToleranceExceededError
70+
"""
71+
# First check that the number of occurrences of floats appearing match
72+
want_str = [g[0] for g in float_regex.findall(want)]
73+
got_str = [g[0] for g in float_regex.findall(got)]
74+
if len(want_str) != len(got_str):
75+
raise ToleranceExceededError()
76+
77+
# Then check the numbers
78+
want_values = [RIFtol(g) for g in want_str]
79+
want_intervals = [add_tolerance(v, want) for v in want_values]
80+
got_values = [RIFtol(g) for g in got_str]
81+
# The doctest is not successful if one of the "want" and "got"
82+
# intervals have an empty intersection
83+
if not all(a.overlaps(b) for a, b in zip(want_intervals, got_values)):
84+
raise ToleranceExceededError()
85+
86+
# Then check the part of the doctests without the numbers
87+
# Continue the check process with floats replaced by stars
88+
want = float_regex.sub('*', want)
89+
got = float_regex.sub('*', got)
90+
return [want, got]
91+
92+
93+
# match 1.0 or 1.0 + I or 1.0 + 2.0*I
94+
real_plus_optional_imag = ''.join([
95+
r'\s*(?P<real>[+-]?\s*',
96+
float_without_sign,
97+
r')(\s*(?P<real_imag_coeff>[+-]\s*',
98+
float_without_sign,
99+
r')\*I|\s*(?P<real_imag_unit>[+-])\s*I)?',
100+
])
101+
102+
103+
# match - 2.0*I
104+
only_imag = ''.join([
105+
r'\s*(?P<only_imag>[+-]?\s*',
106+
float_without_sign,
107+
r')\*I',
108+
])
109+
110+
111+
# match I or -I (no digits), require a non-word part before and after for specificity
112+
imaginary_unit = r'(?P<unit_imag_pre>^|\W)(?P<unit_imag>[+-]?)I(?P<unit_imag_post>$|\W)'
113+
114+
115+
complex_regex = re.compile(''.join([
116+
'(',
117+
only_imag,
118+
'|',
119+
imaginary_unit,
120+
'|',
121+
real_plus_optional_imag,
122+
')',
123+
]))
124+
125+
126+
def complex_match_to_real_and_imag(m: re.Match) -> tuple[str, str]:
127+
"""
128+
Extract real and imaginary part from match
129+
130+
INPUT:
131+
132+
- ``m`` -- match from ``complex_regex``
133+
134+
OUTPUT:
135+
136+
Pair of real and complex parts (as string)
137+
138+
EXAMPLES::
139+
140+
sage: from sage.doctest.check_tolerance import complex_match_to_real_and_imag, complex_regex
141+
sage: complex_match_to_real_and_imag(complex_regex.match('1.0'))
142+
('1.0', '0')
143+
sage: complex_match_to_real_and_imag(complex_regex.match('-1.0 - I'))
144+
('-1.0', '-1')
145+
sage: complex_match_to_real_and_imag(complex_regex.match('1.0 - 3.0*I'))
146+
('1.0', '- 3.0')
147+
sage: complex_match_to_real_and_imag(complex_regex.match('1.0*I'))
148+
('0', '1.0')
149+
sage: complex_match_to_real_and_imag(complex_regex.match('- 2.0*I'))
150+
('0', '- 2.0')
151+
sage: complex_match_to_real_and_imag(complex_regex.match('-I'))
152+
('0', '-1')
153+
sage: for match in complex_regex.finditer('[1, -1, I, -1, -I]'):
154+
....: print(complex_match_to_real_and_imag(match))
155+
('1', '0')
156+
('-1', '0')
157+
('0', '1')
158+
('-1', '0')
159+
('0', '-1')
160+
sage: for match in complex_regex.finditer('[1, -1.3, -1.5 + 0.1*I, 0.5 - 0.1*I, -1.5*I]'):
161+
....: print(complex_match_to_real_and_imag(match))
162+
('1', '0')
163+
('-1.3', '0')
164+
('-1.5', '+ 0.1')
165+
('0.5', '- 0.1')
166+
('0', '-1.5')
167+
"""
168+
real = m.group('real')
169+
if real is not None:
170+
real_imag_coeff = m.group('real_imag_coeff')
171+
real_imag_unit = m.group('real_imag_unit')
172+
if real_imag_coeff is not None:
173+
return (real, real_imag_coeff)
174+
elif real_imag_unit is not None:
175+
return (real, real_imag_unit + '1')
176+
else:
177+
return (real, '0')
178+
only_imag = m.group('only_imag')
179+
if only_imag is not None:
180+
return ('0', only_imag)
181+
unit_imag = m.group('unit_imag')
182+
if unit_imag is not None:
183+
return ('0', unit_imag + '1')
184+
assert False, 'unreachable'
185+
186+
187+
def complex_star_repl(m: re.Match):
188+
"""
189+
Replace the complex number in the match with '*'
190+
"""
191+
if m.group('unit_imag') is not None:
192+
# preserve the matched non-word part
193+
return ''.join([
194+
(m.group('unit_imag_pre') or '').strip(),
195+
'*',
196+
(m.group('unit_imag_post') or '').strip(),
197+
])
198+
else:
199+
return '*'
200+
201+
202+
def check_tolerance_complex_domain(want: MarkedOutput, got: str) -> tuple[str, str]:
203+
"""
204+
Compare want and got over complex domain with tolerance
205+
206+
INPUT:
207+
208+
- ``want`` -- a string, what you want
209+
- ``got`` -- a string, what you got
210+
211+
OUTPUT:
212+
213+
The strings to compare, but with matching complex numbers replaced by asterisk.
214+
215+
EXAMPLES::
216+
217+
sage: from sage.doctest.check_tolerance import check_tolerance_complex_domain
218+
sage: from sage.doctest.marked_output import MarkedOutput
219+
sage: check_tolerance_complex_domain(
220+
....: MarkedOutput('foo:[0.2 + 0.1*I]').update(abs_tol=0.3),
221+
....: 'bar:[0.4]')
222+
['foo:[*]', 'bar:[*]']
223+
sage: check_tolerance_complex_domain(
224+
....: MarkedOutput('foo:-0.5 - 0.1*I').update(abs_tol=2),
225+
....: 'bar:1')
226+
['foo:*', 'bar:*']
227+
sage: check_tolerance_complex_domain(
228+
....: MarkedOutput('foo:[1.0*I]').update(abs_tol=0.3),
229+
....: 'bar:[I]')
230+
['foo:[*]', 'bar:[*]']
231+
sage: check_tolerance_complex_domain(MarkedOutput('foo:0.2 + 0.1*I').update(abs_tol=0.3), 'bar:0.6')
232+
Traceback (most recent call last):
233+
...
234+
sage.doctest.check_tolerance.ToleranceExceededError
235+
"""
236+
want_str = []
237+
for match in complex_regex.finditer(want):
238+
want_str.extend(complex_match_to_real_and_imag(match))
239+
got_str = []
240+
for match in complex_regex.finditer(got):
241+
got_str.extend(complex_match_to_real_and_imag(match))
242+
if len(want_str) != len(got_str):
243+
raise ToleranceExceededError()
244+
245+
# Then check the numbers
246+
want_values = [RIFtol(g) for g in want_str]
247+
want_intervals = [add_tolerance(v, want) for v in want_values]
248+
got_values = [RIFtol(g) for g in got_str]
249+
# The doctest is not successful if one of the "want" and "got"
250+
# intervals have an empty intersection
251+
if not all(a.overlaps(b) for a, b in zip(want_intervals, got_values)):
252+
raise ToleranceExceededError()
253+
254+
# Then check the part of the doctests without the numbers
255+
# Continue the check process with floats replaced by stars
256+
want = complex_regex.sub(complex_star_repl, want)
257+
got = complex_regex.sub(complex_star_repl, got)
258+
return [want, got]

src/sage/doctest/control.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -971,7 +971,7 @@ def expand_files_into_sources(self):
971971
sage: DC = DocTestController(DD, [dirname])
972972
sage: DC.expand_files_into_sources()
973973
sage: len(DC.sources)
974-
12
974+
15
975975
sage: DC.sources[0].options.optional
976976
True
977977
@@ -1072,13 +1072,16 @@ def sort_sources(self):
10721072
sage.doctest.util
10731073
sage.doctest.test
10741074
sage.doctest.sources
1075+
sage.doctest.rif_tol
10751076
sage.doctest.reporting
10761077
sage.doctest.parsing_test
10771078
sage.doctest.parsing
1079+
sage.doctest.marked_output
10781080
sage.doctest.forker
10791081
sage.doctest.fixtures
10801082
sage.doctest.external
10811083
sage.doctest.control
1084+
sage.doctest.check_tolerance
10821085
sage.doctest.all
10831086
sage.doctest
10841087
"""

src/sage/doctest/marked_output.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
"""
2+
Helper for attaching tolerance information to strings
3+
"""
4+
5+
# ****************************************************************************
6+
# Copyright (C) 2012-2018 David Roe <[email protected]>
7+
# 2012 Robert Bradshaw <[email protected]>
8+
# 2012 William Stein <[email protected]>
9+
# 2013 R. Andrew Ohana
10+
# 2013 Volker Braun
11+
# 2013-2018 Jeroen Demeyer <[email protected]>
12+
# 2016-2021 Frédéric Chapoton
13+
# 2017-2018 Erik M. Bray
14+
# 2020 Marc Mezzarobba
15+
# 2020-2023 Matthias Koeppe
16+
# 2022 John H. Palmieri
17+
# 2022 Sébastien Labbé
18+
# 2023 Kwankyu Lee
19+
#
20+
# Distributed under the terms of the GNU General Public License (GPL)
21+
# as published by the Free Software Foundation; either version 2 of
22+
# the License, or (at your option) any later version.
23+
# https://www.gnu.org/licenses/
24+
# ****************************************************************************
25+
26+
27+
class MarkedOutput(str):
28+
"""
29+
A subclass of string with context for whether another string
30+
matches it.
31+
32+
EXAMPLES::
33+
34+
sage: from sage.doctest.marked_output import MarkedOutput
35+
sage: s = MarkedOutput("abc")
36+
sage: s.rel_tol
37+
0
38+
sage: s.update(rel_tol = .05)
39+
'abc'
40+
sage: s.rel_tol
41+
0.0500000000000000
42+
43+
sage: MarkedOutput("56 µs")
44+
'56 \xb5s'
45+
"""
46+
random = False
47+
rel_tol = 0
48+
abs_tol = 0
49+
tol = 0
50+
51+
def update(self, **kwds):
52+
"""
53+
EXAMPLES::
54+
55+
sage: from sage.doctest.marked_output import MarkedOutput
56+
sage: s = MarkedOutput("0.0007401")
57+
sage: s.update(abs_tol = .0000001)
58+
'0.0007401'
59+
sage: s.rel_tol
60+
0
61+
sage: s.abs_tol
62+
1.00000000000000e-7
63+
"""
64+
self.__dict__.update(kwds)
65+
return self
66+
67+
def __reduce__(self):
68+
"""
69+
Pickling.
70+
71+
EXAMPLES::
72+
73+
sage: from sage.doctest.marked_output import MarkedOutput
74+
sage: s = MarkedOutput("0.0007401")
75+
sage: s.update(abs_tol = .0000001)
76+
'0.0007401'
77+
sage: t = loads(dumps(s)) # indirect doctest
78+
sage: t == s
79+
True
80+
sage: t.abs_tol
81+
1.00000000000000e-7
82+
"""
83+
return make_marked_output, (str(self), self.__dict__)
84+
85+
86+
def make_marked_output(s, D):
87+
"""
88+
Auxiliary function for pickling.
89+
90+
EXAMPLES::
91+
92+
sage: from sage.doctest.marked_output import make_marked_output
93+
sage: s = make_marked_output("0.0007401", {'abs_tol':.0000001})
94+
sage: s
95+
'0.0007401'
96+
sage: s.abs_tol
97+
1.00000000000000e-7
98+
"""
99+
ans = MarkedOutput(s)
100+
ans.__dict__.update(D)
101+
return ans

0 commit comments

Comments
 (0)