Skip to content

Commit 1962918

Browse files
mklssMichael Klooss
authored andcommitted
Add reseed_rng option to p_iter_fork
1 parent b9e396a commit 1962918

File tree

2 files changed

+33
-7
lines changed

2 files changed

+33
-7
lines changed

src/sage/parallel/decorate.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,7 @@ def parallel(p_iter='fork', ncpus=None, **kwds):
308308
- ``ncpus`` -- integer; maximal number of subprocesses to use at the same time
309309
- ``timeout`` -- number of seconds until each subprocess is killed (only supported
310310
by ``'fork'``; zero means not at all)
311+
- ``reseed_rng``: reseed the rng in each subprocess
311312
312313
.. warning::
313314
@@ -398,6 +399,16 @@ def parallel(p_iter='fork', ncpus=None, **kwds):
398399
sage: Foo.square_classmethod(3)
399400
9
400401
402+
403+
By default, all subprocesses use the same random seed and therefore the same deterministic randomness.
404+
For functions that should be randomized, we can reseed the random seed in each subprocess::
405+
406+
sage: @parallel(reseed_rng = True)
407+
....: def unif(n): return ZZ.random_element(x = 0, y = n)
408+
sage: set_random_seed(42)
409+
sage: sorted(unif([1000]*3))
410+
[(((1000,), {}), 444), (((1000,), {}), 597), (((1000,), {}), 640)]
411+
401412
.. warning::
402413
403414
Currently, parallel methods do not work with the

src/sage/parallel/use_fork.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from sage.interfaces.process import ContainChildren
2020
from sage.misc.timing import walltime
2121

22+
from sage.misc.randstate import set_random_seed
23+
from sage.misc.prandom import getrandbits
2224

2325
class WorkerData:
2426
"""
@@ -68,6 +70,7 @@ class p_iter_fork:
6870
about what the iterator does (e.g., killing subprocesses)
6971
- ``reset_interfaces`` -- boolean (default: ``True``); whether to reset all
7072
pexpect interfaces
73+
- ``reseed_rng`` -- booolean (default: ``False``); whether or not to reseed the rng in the subprocesses
7174
7275
EXAMPLES::
7376
@@ -80,7 +83,7 @@ class p_iter_fork:
8083
sage: X.verbose
8184
False
8285
"""
83-
def __init__(self, ncpus, timeout=0, verbose=False, reset_interfaces=True):
86+
def __init__(self, ncpus, timeout=0, verbose=False, reset_interfaces=True, reseed_rng=False):
8487
"""
8588
Create a ``fork()``-based parallel iterator.
8689
@@ -103,6 +106,8 @@ def __init__(self, ncpus, timeout=0, verbose=False, reset_interfaces=True):
103106
self.timeout = float(timeout) # require a float
104107
self.verbose = verbose
105108
self.reset_interfaces = reset_interfaces
109+
self.reseed_rng = reseed_rng
110+
self.worker_seed = None
106111

107112
def __call__(self, f, inputs):
108113
"""
@@ -148,8 +153,6 @@ def __call__(self, f, inputs):
148153
sage: list(Polygen([QQ,QQ]))
149154
[(((Rational Field,), {}), x), (((Rational Field,), {}), x)]
150155
"""
151-
n = self.ncpus
152-
v = list(inputs)
153156
import os
154157
import sys
155158
import signal
@@ -158,21 +161,28 @@ def __call__(self, f, inputs):
158161
dir = tmp_dir()
159162
timeout = self.timeout
160163

164+
n = self.ncpus
165+
inputs = list(inputs)
166+
if self.reseed_rng:
167+
seeds = [getrandbits(512) for _ in range(0, len(inputs))]
168+
vs = list(zip(inputs, seeds))
169+
else:
170+
vs = list(zip(inputs, [None]*len(inputs)))
161171
workers = {}
162172
try:
163-
while v or workers:
173+
while vs or workers:
164174
# Spawn up to n subprocesses
165-
while v and len(workers) < n:
166-
v0 = v.pop(0) # Input value for the next subprocess
175+
while vs and len(workers) < n:
176+
(v0, seed0) = vs.pop(0) # Input value and seed for the next subprocess
167177
with ContainChildren():
168178
pid = os.fork()
169179
# The way fork works is that pid returns the
170180
# nonzero pid of the subprocess for the master
171181
# process and returns 0 for the subprocess.
172182
if not pid:
173183
# This is the subprocess.
184+
self.worker_seed = seed0 if self.reseed_rng else None
174185
self._subprocess(f, dir, *v0)
175-
176186
workers[pid] = WorkerData(v0)
177187

178188
if len(workers) > 0:
@@ -304,6 +314,11 @@ def _subprocess(self, f, dir, args, kwds={}):
304314
else:
305315
invalidate_all()
306316

317+
# Reseed rng, if requested.
318+
if self.reseed_rng == True:
319+
set_random_seed(self.worker_seed)
320+
321+
307322
# Now evaluate the function f.
308323
value = f(*args, **kwds)
309324

0 commit comments

Comments
 (0)