Commit f0581cb

Release Manager committed
gh-39025: Add reseed_rng option to p_iter_fork

This PR implements a convenience feature for the `@parallel` decorator, or more precisely, for `p_iter_fork`. Setting `reseed_rng=True` resets the random number generator in each subprocess by running `set_random_seed(self.worker_seed)`. The worker seeds are derived deterministically, so the results are always reproducible. The new behaviour is useful for running probabilistic experiments while taking advantage of multiple cores.

The PR adds one doctest. It would make sense to have additional (randomized) tests to assert that future modifications don't break the reproducibility. However, I failed to locate a comprehensive test suite for sage/parallel; in particular, there are no existing tests for the reproducibility of parallel computations. The `sage/tests` folder seems to contain other, non-specific tests?

Coding style w.r.t. <https://doc.sagemath.org/html/en/reference/misc/sage/misc/randstate.html>

- The documentation says `with seed(worker_seed)` should be used. I did not use it, because the current implementation reseeds the whole subprocess once and for all.
- The documentation says NTL does not reproduce. This is a very surprising and unfortunate state of affairs. I expect that `reseed_rng` inherits this caveat. I did not make it explicit in the docstring because, if the problem occurs, `reseed_rng=False` should have it as well, and there is no mention of it there.

### 📝 Checklist

- [x] The title is concise and informative.
- [x] The description explains in detail what this PR is about.
- [x] I have linked a relevant issue or discussion.
- [x] I have created tests covering the changes.
- [x] I have updated the documentation and checked the documentation preview.

The doctests passed for me, but the preview via running `sage --docbuild tutorial html` gives me `ImportError: cannot import name 'count_all_local_good_types_normal_form' from 'sage.quadratic_forms.count_local_2'`, so the doc preview does not work for me, although that probably has nothing to do with this PR.

URL: #39025
Reported by: mklss
Reviewer(s): Kwankyu Lee
2 parents 290b261 + 8518218 commit f0581cb
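
The reproducibility claim in the description can be illustrated outside Sage. The following is a minimal sketch, assuming Python's standard `random` module as a stand-in for Sage's `set_random_seed` and `getrandbits`; the helper names `derive_worker_seeds` and `run_task` are hypothetical and do not appear in the commit. It draws one seed per task from a seeded master generator and shows that each task's result depends only on its own seed, not on the order in which tasks happen to run.

```python
import random


def derive_worker_seeds(master_seed, n_tasks, bits=512):
    """Draw one seed per task from a generator seeded with master_seed.

    Hypothetical helper: it plays the role of the seed list built in the
    parent process before any subprocess is started.
    """
    master = random.Random(master_seed)
    return [master.getrandbits(bits) for _ in range(n_tasks)]


def run_task(n, seed):
    """Stand-in for one subprocess: reseed a private generator, then draw."""
    rng = random.Random(seed)
    return rng.randrange(n)


inputs = [1000] * 3
seeds = derive_worker_seeds(42, len(inputs))

# Run the tasks in input order.
first = [run_task(n, s) for n, s in zip(inputs, seeds)]

# Run the tasks in reverse order to mimic a different scheduling,
# then put the results back into input order.
second = [run_task(n, s) for n, s in reversed(list(zip(inputs, seeds)))][::-1]

print(first == second)  # prints True: each result depends only on its seed
```

Because the seeds are fixed before any task runs, the outcome does not depend on how many workers are available or which one finishes first.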

2 files changed: +34 -7 lines changed

src/sage/parallel/decorate.py

Lines changed: 10 additions & 0 deletions

@@ -308,6 +308,7 @@ def parallel(p_iter='fork', ncpus=None, **kwds):
     - ``ncpus`` -- integer; maximal number of subprocesses to use at the same time
     - ``timeout`` -- number of seconds until each subprocess is killed (only supported
       by ``'fork'``; zero means not at all)
+    - ``reseed_rng``: reseed the rng (random number generator) in each subprocess
 
     .. warning::
 
@@ -398,6 +399,15 @@ def parallel(p_iter='fork', ncpus=None, **kwds):
         sage: Foo.square_classmethod(3)
         9
 
+    By default, all subprocesses use the same random seed and therefore the same deterministic randomness.
+    For functions that should be randomized, we can reseed the random seed in each subprocess::
+
+        sage: @parallel(reseed_rng=True)
+        ....: def unif(n): return ZZ.random_element(x=0, y=n)
+        sage: set_random_seed(42)
+        sage: sorted(unif([1000]*3)) # random
+        [(((1000,), {}), 444), (((1000,), {}), 597), (((1000,), {}), 640)]
+
     .. warning::
 
         Currently, parallel methods do not work with the

src/sage/parallel/use_fork.py

Lines changed: 24 additions & 7 deletions

@@ -19,6 +19,9 @@
 from sage.interfaces.process import ContainChildren
 from sage.misc.timing import walltime
 
+from sage.misc.randstate import set_random_seed
+from sage.misc.prandom import getrandbits
+
 
 class WorkerData:
     """
@@ -68,6 +71,8 @@ class p_iter_fork:
       about what the iterator does (e.g., killing subprocesses)
     - ``reset_interfaces`` -- boolean (default: ``True``); whether to reset all
       pexpect interfaces
+    - ``reseed_rng`` -- boolean (default: ``False``); whether or not to reseed
+      the rng in the subprocesses
 
     EXAMPLES::
 
@@ -80,7 +85,7 @@ class p_iter_fork:
         sage: X.verbose
         False
     """
-    def __init__(self, ncpus, timeout=0, verbose=False, reset_interfaces=True):
+    def __init__(self, ncpus, timeout=0, verbose=False, reset_interfaces=True, reseed_rng=False):
         """
         Create a ``fork()``-based parallel iterator.
 
@@ -103,6 +108,8 @@ def __init__(self, ncpus, timeout=0, verbose=False, reset_interfaces=True):
         self.timeout = float(timeout) # require a float
         self.verbose = verbose
         self.reset_interfaces = reset_interfaces
+        self.reseed_rng = reseed_rng
+        self.worker_seed = None
 
     def __call__(self, f, inputs):
         """
@@ -148,8 +155,6 @@ def __call__(self, f, inputs):
             sage: list(Polygen([QQ,QQ]))
             [(((Rational Field,), {}), x), (((Rational Field,), {}), x)]
         """
-        n = self.ncpus
-        v = list(inputs)
         import os
         import sys
         import signal
@@ -158,21 +163,29 @@ def __call__(self, f, inputs):
         dir = tmp_dir()
         timeout = self.timeout
 
+        n = self.ncpus
+        inputs = list(inputs)
+        if self.reseed_rng:
+            seeds = [getrandbits(512) for _ in range(len(inputs))]
+            vs = list(zip(inputs, seeds))
+        else:
+            vs = list(zip(inputs, [None]*len(inputs)))
         workers = {}
         try:
-            while v or workers:
+            while vs or workers:
                 # Spawn up to n subprocesses
-                while v and len(workers) < n:
-                    v0 = v.pop(0) # Input value for the next subprocess
+                while vs and len(workers) < n:
+                    v0, seed0 = vs.pop(0) # Input value and seed for the next subprocess
                     with ContainChildren():
                         pid = os.fork()
                         # The way fork works is that pid returns the
                         # nonzero pid of the subprocess for the master
                         # process and returns 0 for the subprocess.
                         if not pid:
                             # This is the subprocess.
+                            if self.reseed_rng:
+                                self.worker_seed = seed0
                             self._subprocess(f, dir, *v0)
-
                     workers[pid] = WorkerData(v0)
 
                 if len(workers) > 0:
@@ -304,6 +317,10 @@ def _subprocess(self, f, dir, args, kwds={}):
                 else:
                     invalidate_all()
 
+            # Reseed rng, if requested.
+            if self.reseed_rng:
+                set_random_seed(self.worker_seed)
+
             # Now evaluate the function f.
             value = f(*args, **kwds)
 
