Skip to content

Commit c9e5c99

Browse files
committed
Allow construct RIF element from question-style string
1 parent ff9d834 commit c9e5c99

File tree

3 files changed

+251
-4
lines changed

3 files changed

+251
-4
lines changed

src/sage/libs/mpfi/__init__.pxd

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ cdef extern from "mpfi.h":
2626
int mpfi_set_z(mpfi_ptr, mpz_t)
2727
int mpfi_set_q(mpfi_ptr, mpq_t)
2828
int mpfi_set_fr(mpfi_ptr, mpfr_srcptr)
29-
int mpfi_set_str(mpfi_ptr, char *, int)
29+
int mpfi_set_str(mpfi_ptr, const char *, int)
3030

3131
# combined initialization and assignment functions
3232
int mpfi_init_set(mpfi_ptr, mpfi_srcptr)
@@ -36,7 +36,7 @@ cdef extern from "mpfi.h":
3636
int mpfi_init_set_z(mpfi_ptr, mpz_srcptr)
3737
int mpfi_init_set_q(mpfi_ptr, mpq_srcptr)
3838
int mpfi_init_set_fr(mpfi_ptr, mpfr_srcptr)
39-
int mpfi_init_set_str(mpfi_ptr, char *, int)
39+
int mpfi_init_set_str(mpfi_ptr, const char *, int)
4040

4141
# swapping two intervals
4242
void mpfi_swap(mpfi_ptr, mpfi_ptr)

src/sage/libs/mpfr/__init__.pxd

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ cdef extern from "mpfr.h":
2727
# int mpfr_set_f(mpfr_t rop, mpf_t op, mpfr_rnd_t rnd)
2828
int mpfr_set_ui_2exp(mpfr_t rop, unsigned long int op, mp_exp_t e, mpfr_rnd_t rnd)
2929
int mpfr_set_si_2exp(mpfr_t rop, long int op, mp_exp_t e, mpfr_rnd_t rnd)
30-
int mpfr_set_str(mpfr_t rop, char *s, int base, mpfr_rnd_t rnd)
30+
int mpfr_set_str(mpfr_t rop, const char *s, int base, mpfr_rnd_t rnd)
3131
int mpfr_strtofr(mpfr_t rop, char *nptr, char **endptr, int base, mpfr_rnd_t rnd)
3232
void mpfr_set_inf(mpfr_t x, int sign)
3333
void mpfr_set_nan(mpfr_t x)
@@ -43,7 +43,7 @@ cdef extern from "mpfr.h":
4343
int mpfr_init_set_z(mpfr_t rop, mpz_t op, mpfr_rnd_t rnd)
4444
int mpfr_init_set_q(mpfr_t rop, mpq_t op, mpfr_rnd_t rnd)
4545
# int mpfr_init_set_f(mpfr_t rop, mpf_t op, mpfr_rnd_t rnd)
46-
int mpfr_init_set_str(mpfr_t x, char *s, int base, mpfr_rnd_t rnd)
46+
int mpfr_init_set_str(mpfr_t x, const char *s, int base, mpfr_rnd_t rnd)
4747

4848
# Conversion Functions
4949
double mpfr_get_d(mpfr_t op, mpfr_rnd_t rnd)

src/sage/rings/convert/mpfi.pyx

Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,16 @@ Convert Sage/Python objects to real/complex intervals
1111
# http://www.gnu.org/licenses/
1212
#*****************************************************************************
1313

14+
import re
15+
1416
from cpython.float cimport PyFloat_AS_DOUBLE
1517
from cpython.complex cimport PyComplex_RealAsDouble, PyComplex_ImagAsDouble
1618

19+
from libc.stdio cimport printf
20+
1721
from sage.libs.mpfr cimport *
1822
from sage.libs.mpfi cimport *
23+
from sage.libs.gmp.mpz cimport *
1924
from sage.libs.gsl.complex cimport *
2025

2126
from sage.arith.long cimport integer_check_long
@@ -45,6 +50,243 @@ cdef inline int return_real(mpfi_ptr im) noexcept:
4550
return 0
4651

4752

53+
NUMBER = re.compile(rb'([+-]?(0[XxBb])?[0-9A-Za-z]+)\.([0-9A-Za-z]*)\?([0-9]*)(?:([EePp@])([+-]?[0-9]+))?')
54+
# example: -0xABC.DEF?12@5
55+
# match groups: (-0xABC) (0x) (DEF) (12) (@) (5)
56+
57+
cdef int _from_str_question_style(mpfi_ptr x, bytes s, int base) except -1:
58+
"""
59+
Convert a string in question style to an MPFI interval.
60+
61+
INPUT:
62+
63+
- ``x`` -- a pre-initialized MPFI interval
64+
65+
- ``s`` -- the string to convert
66+
67+
- ``base`` -- base to use for string conversion
68+
69+
OUTPUT:
70+
71+
- if conversion is possible: set ``x`` and return 0.
72+
73+
- in all other cases: return some nonzero value, or raise an exception.
74+
75+
TESTS:
76+
77+
Double check that ``ZZ``, ``RR`` and ``RIF`` follows the string
78+
conversion rule for base different from `10` (except ``ZZ``
79+
which only allows base up to `36`)::
80+
81+
sage: ZZ("0x123", base=0)
82+
291
83+
sage: RR("0x123.e1", base=0) # rel tol 1e-12
84+
291.878906250000
85+
sage: RR("0x123.@1", base=0) # rel tol 1e-12
86+
4656.00000000000
87+
sage: RIF("0x123.4@1", base=0)
88+
4660
89+
sage: ZZ("1Xx", base=36) # case insensitive
90+
2517
91+
sage: ZZ("1Xx", base=62)
92+
Traceback (most recent call last):
93+
...
94+
ValueError: base (=62) must be 0 or between 2 and 36
95+
sage: RR("1Xx", base=36) # rel tol 1e-12
96+
2517.00000000000
97+
sage: RR("0x123", base=36) # rel tol 1e-12
98+
1.54101900000000e6
99+
sage: RR("-1Xx@-1", base=62) # rel tol 1e-12
100+
-95.9516129032258
101+
sage: RIF("1Xx@-1", base=62) # rel tol 1e-12
102+
95.95161290322580?
103+
sage: RIF("1aE1", base=11)
104+
Traceback (most recent call last):
105+
...
106+
TypeError: unable to convert '1aE1' to real interval
107+
sage: RIF("1aE1", base=11)
108+
Traceback (most recent call last):
109+
...
110+
TypeError: unable to convert '1aE1' to real interval
111+
112+
General checks::
113+
114+
sage: RIF("123456.?2").endpoints() # rel tol 1e-12
115+
(123454.0, 123458.0)
116+
sage: RIF("1234.56?2").endpoints() # rel tol 1e-12
117+
(1234.54, 1234.58)
118+
sage: RIF("1234.56?2e2").endpoints() # rel tol 1e-12
119+
(123454.0, 123458.0)
120+
sage: x = RIF("-1234.56?2e2"); x.endpoints() # rel tol 1e-12
121+
(-123458.0, -123454.0)
122+
sage: x
123+
-1.2346?e5
124+
sage: x.str(style="question", error_digits=1)
125+
'-123456.?2'
126+
sage: RIF("1.?100").endpoints() # rel tol 1e-12
127+
(-99.0, 101.0)
128+
sage: RIF("1.?100").str(style="question", error_digits=3)
129+
'1.?100'
130+
131+
Large exponent (ensure precision is not lost)::
132+
133+
sage: x = RIF("1.123456?2e1000000000"); x
134+
1.12346?e1000000000
135+
sage: x.str(style="question", error_digits=3)
136+
'1.12345600?201e1000000000'
137+
138+
Large precision::
139+
140+
sage: F = RealIntervalField(1000)
141+
sage: x = F(sqrt(2)); x.endpoints() # rel tol 1e-290
142+
(1.41421356237309504880168872420969807856967187537694807317667973799073247846210703885038753432764157273501384623091229702492483605585073721264412149709993583141322266592750559275579995050115278206057147010955997160597027453459686201472851741864088919860955232923048430871432145083976260362799525140798,
143+
1.41421356237309504880168872420969807856967187537694807317667973799073247846210703885038753432764157273501384623091229702492483605585073721264412149709993583141322266592750559275579995050115278206057147010955997160597027453459686201472851741864088919860955232923048430871432145083976260362799525140799)
144+
sage: x in F(x.str(style="question", error_digits=3))
145+
True
146+
sage: x in F(x.str(style="question", error_digits=0))
147+
True
148+
sage: F("1.123456789123456789123456789123456789123456789123456789123456789123456789?987654321987654321987654321e500").endpoints() # rel tol 1e-290
149+
(1.123456789123456789123456789123456789123456788135802467135802467135802468e500,
150+
1.12345678912345678912345678912345678912345679011111111111111111111111111e500)
151+
152+
Stress test::
153+
154+
sage: for F in [RealIntervalField(15), RIF, RealIntervalField(100), RealIntervalField(1000)]:
155+
....: for i in range(1000):
156+
....: a, b = randint(-10^9, 10^9), randint(0, 50)
157+
....: c, d = randint(-2^b, 2^b), randint(2, 5)
158+
....: x = a * F(d)^c
159+
....: assert x in F(x.str(style="question", error_digits=3)), (x, a, c, d)
160+
....: assert x in F(x.str(style="question", error_digits=0)), (x, a, c, d)
161+
162+
Base different from `10` (note that the error and exponent are specified in decimal)::
163+
164+
sage: RIF("10000.?0", base=2).endpoints() # rel tol 1e-12
165+
(16.0, 16.0)
166+
sage: RIF("10000.?0e10", base=2).endpoints() # rel tol 1e-12
167+
(16384.0, 16384.0)
168+
sage: x = RIF("10000.?10", base=2); x.endpoints() # rel tol 1e-12
169+
(6.0, 26.0)
170+
sage: x.str(base=2, style="question", error_digits=2)
171+
'10000.000?80'
172+
sage: x = RIF("10000.000?80", base=2); x.endpoints() # rel tol 1e-12
173+
(6.0, 26.0)
174+
sage: x = RIF("12a.?", base=16); x.endpoints() # rel tol 1e-12
175+
(297.0, 299.0)
176+
sage: x = RIF("12a.BcDeF?", base=16); x.endpoints() # rel tol 1e-12
177+
(298.737775802611, 298.737777709962)
178+
sage: x = RIF("12a.BcDeF?@10", base=16); x.endpoints() # rel tol 1e-12
179+
(3.28465658150911e14, 3.28465660248065e14)
180+
sage: x = RIF("12a.BcDeF?p10", base=16); x.endpoints() # rel tol 1e-12
181+
(305907.482421875, 305907.484375000)
182+
sage: x = RIF("0x12a.BcDeF?p10", base=0); x.endpoints() # rel tol 1e-12
183+
(305907.482421875, 305907.484375000)
184+
185+
Space is allowed::
186+
187+
sage: RIF("-1234.56?2").endpoints() # rel tol 1e-12
188+
(-1234.58, -1234.54)
189+
sage: RIF("- 1234.56 ?2").endpoints() # rel tol 1e-12
190+
(-1234.58, -1234.54)
191+
192+
Erroneous input::
193+
194+
sage: RIF("1234.56?2e2.3")
195+
Traceback (most recent call last):
196+
...
197+
TypeError: unable to convert '1234.56?2e2.3' to real interval
198+
sage: RIF("1234?2") # decimal point required
199+
Traceback (most recent call last):
200+
...
201+
TypeError: unable to convert '1234?2' to real interval
202+
sage: RIF("1234.?2e")
203+
Traceback (most recent call last):
204+
...
205+
TypeError: unable to convert '1234.?2e' to real interval
206+
sage: RIF("1.?e999999999999999999999999")
207+
[-infinity .. +infinity]
208+
sage: RIF("0X1.?", base=33) # X is not valid digit in base 33
209+
Traceback (most recent call last):
210+
...
211+
TypeError: unable to convert '0X1.?' to real interval
212+
sage: RIF("1.a?1e10", base=12)
213+
Traceback (most recent call last):
214+
...
215+
TypeError: unable to convert '1.a?1e10' to real interval
216+
sage: RIF("1.1?a@10", base=12)
217+
Traceback (most recent call last):
218+
...
219+
TypeError: unable to convert '1.1?a@10' to real interval
220+
sage: RIF("0x1?2e1", base=0) # e is not allowed in base > 10, use @ instead
221+
Traceback (most recent call last):
222+
...
223+
TypeError: unable to convert '0x1?2e1' to real interval
224+
sage: RIF("0x1?2p1", base=36)
225+
Traceback (most recent call last):
226+
...
227+
TypeError: unable to convert '0x1?2p1' to real interval
228+
"""
229+
cdef mpz_t error_part
230+
cdef mpfi_t error
231+
cdef mpfr_t radius, neg_radius
232+
cdef bytes int_part_string, base_prefix, frac_part_string, error_string, e, sci_expo_string, optional_expo, tmp
233+
234+
match = NUMBER.fullmatch(s)
235+
if match is None:
236+
return 1
237+
int_part_string, base_prefix, frac_part_string, error_string, e, sci_expo_string = match.groups()
238+
239+
if (base > 10 or (base == 0 and base_prefix in (b'0X', b'0X'))) and e in (b'e', b'E'):
240+
return 1
241+
if base > 16 and e in (b'p', b'P'):
242+
return 1
243+
if base > 16 or not base_prefix:
244+
base_prefix = b''
245+
246+
if error_string:
247+
if mpz_init_set_str(error_part, error_string, 10):
248+
mpz_clear(error_part)
249+
return 1
250+
else:
251+
mpz_init_set_ui(error_part, 1)
252+
253+
optional_expo = e + sci_expo_string if e else b''
254+
if mpfi_set_str(x, int_part_string + b'.' + frac_part_string + optional_expo, base):
255+
mpz_clear(error_part)
256+
return 1
257+
258+
mpfr_init2(radius, mpfi_get_prec(x))
259+
tmp = base_prefix + (
260+
b'0.' + b'0'*(len(frac_part_string)-1) + b'1' + optional_expo
261+
if frac_part_string else
262+
b'1.' + optional_expo)
263+
# if base = 0:
264+
# when s = '-0x123.456@7', tmp = '0x0.001@7'
265+
# when s = '-0x123.@7', tmp = '0x1.@7'
266+
# if base = 36:
267+
# when s = '-0x123.456@7', tmp = '0.001@7'
268+
if mpfr_set_str(radius, tmp, base, MPFR_RNDU):
269+
mpfr_clear(radius)
270+
mpz_clear(error_part)
271+
return 1
272+
273+
mpfr_mul_z(radius, radius, error_part, MPFR_RNDU)
274+
mpz_clear(error_part)
275+
276+
mpfr_init2(neg_radius, mpfi_get_prec(x))
277+
mpfr_neg(neg_radius, radius, MPFR_RNDD)
278+
279+
mpfi_init2(error, mpfi_get_prec(x))
280+
mpfi_interv_fr(error, neg_radius, radius)
281+
mpfr_clear(radius)
282+
mpfr_clear(neg_radius)
283+
284+
mpfi_add(x, x, error)
285+
mpfi_clear(error)
286+
287+
return 0
288+
289+
48290
cdef int mpfi_set_sage(mpfi_ptr re, mpfi_ptr im, x, field, int base) except -1:
49291
"""
50292
Convert any object ``x`` to an MPFI interval or a pair of
@@ -186,6 +428,11 @@ cdef int mpfi_set_sage(mpfi_ptr re, mpfi_ptr im, x, field, int base) except -1:
186428
if isinstance(x, unicode):
187429
x = x.encode("ascii")
188430
if isinstance(x, bytes):
431+
if b"?" in x:
432+
if _from_str_question_style(re, (<bytes>x).replace(b' ', b''), base):
433+
x = bytes_to_str(x)
434+
raise TypeError(f"unable to convert {x!r} to real interval")
435+
return return_real(im)
189436
s = (<bytes>x).replace(b'..', b',').replace(b' ', b'').replace(b'+infinity', b'@inf@').replace(b'-infinity', b'-@inf@')
190437
if mpfi_set_str(re, s, base):
191438
x = bytes_to_str(x)

0 commit comments

Comments
 (0)