Skip to content

Commit f242f3c

Browse files
committed
Refactor code and isolate precomputation
1 parent d4ab577 commit f242f3c

File tree

1 file changed

+92
-48
lines changed

1 file changed

+92
-48
lines changed

src/sage/stats/distributions/discrete_gaussian_lattice.py

Lines changed: 92 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def _normalisation_factor_zz(self, tau=None, prec=None):
236236
sage: while v not in counter:
237237
....: add_samples(1000)
238238
239-
sage: while abs(m*f(v)*1.0/nf/counter[v] - 1.0) >= 0.2: # long time, needs sage.symbolic
239+
sage: while abs(m*f(v)*1.0/nf/counter[v] - 1.0) >= 0.2: # long time, needs sage.symbolic
240240
....: add_samples(1000)
241241
242242
sage: D = DGL(ZZ^8, 0.5)
@@ -278,10 +278,10 @@ def f_or_hat(x):
278278
return R(exp(-x / (2 * sigma**2)))
279279

280280
if not self.is_spherical:
281-
# NOTE: This is only a poor approximation placeholder.
281+
# TODO: This is only a poor approximation placeholder.
282282
# It should be easy to implement, since the Fourier transform
283283
# is essentially the same, but I can't figure out how to
284-
# tweak the `.qfrep` call below correctly. TODO.
284+
# tweak the `.qfrep` call below correctly.
285285
from warnings import warn
286286
warn("Note: `_normalisation_factor_zz` has not been properly "\
287287
"implemented for non-spherical distributions.")
@@ -301,7 +301,7 @@ def f_or_hat(x):
301301
if self.is_spherical and not self._c_in_lattice:
302302
raise NotImplementedError("Lattice must contain 0 for now.")
303303

304-
if self.B.base_ring() not in ZZ:
304+
if self.B.base_ring() != ZZ:
305305
raise NotImplementedError("Lattice must be integral for now.")
306306

307307
sigma = self._sigma
@@ -384,14 +384,14 @@ def _randomise(self, v):
384384
"""
385385
return vector(ZZ, [DGI(self.r, c=vi)() for vi in v])
386386

387-
def __init__(self, B, sigma=1, c=None, r=None, precision=None):
387+
def __init__(self, B, sigma=1, c=0, r=None, precision=None, sigma_basis=False):
388388
r"""
389389
Construct a discrete Gaussian sampler over the lattice `Λ(B)`
390390
with parameter ``sigma`` and center `c`.
391391
392392
INPUT:
393393
394-
- ``B`` -- a basis for the lattice, one of the following:
394+
- ``B`` -- a (row) basis for the lattice, one of the following:
395395
396396
- an integer matrix,
397397
- an object with a ``matrix()`` method, e.g. ``ZZ^n``, or
@@ -401,9 +401,10 @@ def __init__(self, B, sigma=1, c=None, r=None, precision=None):
401401
402402
- a real number `σ > 0` (spherical),
403403
- a positive definite matrix `Σ` (non-spherical), or
404-
- any matrix-like ``S``, equivalent to ``Σ = SSᵀ``
404+
- any matrix-like ``S``, equivalent to ``Σ = SSᵀ``, when
405+
``sigma_basis`` is set
405406
406-
- ``c`` -- (default: None) center `c`, any vector in `\ZZ^n` is
407+
- ``c`` -- (default: 0) center `c`, any vector in `\ZZ^n` is
407408
supported, but `c ∈ Λ(B)` is faster.
408409
409410
- ``r`` -- (default: None) rounding parameter `r` as defined in
@@ -413,6 +414,9 @@ def __init__(self, B, sigma=1, c=None, r=None, precision=None):
413414
414415
- ``precision`` -- bit precision `≥ 53`.
415416
417+
- ``sigma_basis`` -- (default: False) When set, ``sigma`` is treated as
418+
a basis, i.e. the covariance matrix is computed by ``Σ = SSᵀ``.
419+
416420
EXAMPLES::
417421
418422
sage: from sage.stats.all import DGL
@@ -459,8 +463,8 @@ def __init__(self, B, sigma=1, c=None, r=None, precision=None):
459463
460464
The non-spherical sampler supports offline computation to speed up
461465
sampling. This will be useful when changing the center `c` is supported.
462-
The difference is more significant for larger matrices. For 128x128 the
463-
author of this sentence see a 4x speedup (86s -> 20s).
466+
The difference is more significant for larger matrices. For 128x128 we
467+
observe a 4x speedup (86s -> 20s).
464468
465469
sage: D.offline_samples = []
466470
sage: T = 2**12
@@ -487,15 +491,15 @@ def __init__(self, B, sigma=1, c=None, r=None, precision=None):
487491
# Check if sigma is a (real) number or a scaled identity matrix
488492
self.is_spherical = True
489493
try:
490-
self._sigma = self._RR(sigma)
494+
self.sigma = self._RR(sigma)
491495
except TypeError:
492-
self._sigma = matrix(self._RR, sigma)
496+
self.sigma = matrix(self._RR, sigma)
493497
# Will it be "annoying" if a matrix Sigma has different behaviour
494498
# sometimes? There should be a parameter in the consrtuctor
495-
if self._sigma == self._sigma[0, 0]:
496-
self._sigma = self._RR(self._sigma[0, 0])
499+
if self.sigma == self.sigma[0, 0]:
500+
self.sigma = self._RR(self.sigma[0, 0])
497501
else:
498-
if not self._sigma.is_positive_definite():
502+
if not self.sigma.is_positive_definite():
499503
raise RuntimeError(f"Sigma(={self._sigma}) is not positive definite")
500504
self.is_spherical = False
501505

@@ -516,36 +520,36 @@ def __init__(self, B, sigma=1, c=None, r=None, precision=None):
516520
self._G = B.gram_schmidt()[0]
517521
self._c_in_lattice = False
518522

519-
try:
520-
c = vector(ZZ, self.n, c)
521-
except TypeError:
522-
try:
523-
c = vector(QQ, self.n, c)
524-
except TypeError:
525-
c = vector(self._RR, self.n, c)
523+
self.D = None
524+
self.VS = None
525+
self._c_mul_B_inv = None
526+
self.r = r
526527

527-
self._c = c
528+
self.c = c
528529

530+
def _precompute_data(self):
531+
r"""
532+
Precomputes basis data.
533+
"""
529534
if self.is_spherical:
530535
# deal with trivial case first, it is common
531536
if self._G == 1 and self.c == 0:
532537
self._c_in_lattice = True
533-
D = DGI(sigma=sigma)
538+
D = DGI(sigma=self.sigma)
534539
self.D = tuple([D for _ in range(self.B.nrows())])
535-
self.VS = FreeModule(ZZ, B.nrows())
536-
return
540+
self.VS = FreeModule(ZZ, self.B.nrows())
537541

538542
else:
539543
try:
540-
w = B.solve_left(c)
541-
if w in ZZ ** B.nrows():
544+
w = self.B.solve_left(self.c)
545+
if w in ZZ ** self.B.nrows():
542546
self._c_in_lattice = True
543547
D = []
544548
for i in range(self.B.nrows()):
545-
sigma_ = self._sigma / self._G[i].norm()
549+
sigma_ = self.sigma / self._G[i].norm()
546550
D.append(DGI(sigma=sigma_))
547551
self.D = tuple(D)
548-
self.VS = FreeModule(ZZ, B.nrows())
552+
self.VS = FreeModule(ZZ, self.B.nrows())
549553
except ValueError:
550554
pass
551555
else:
@@ -555,26 +559,27 @@ def __init__(self, B, sigma=1, c=None, r=None, precision=None):
555559

556560
# Offline samples of B⁻¹D₁
557561
self.offline_samples = []
558-
self.B_inv = B.inverse()
562+
self.B_inv = self.B.inverse()
559563
self.sigma_inv = self.sigma.inverse()
564+
self._c_mul_B_inv = self.c * self.B_inv
560565

561-
if r is None:
566+
if self.r is None:
562567
# Compute the maximal r such that (Sigma - r^2 * Q) > 0
563-
r = self._maximal_r() * 0.9999
564-
r = self._RR(r)
568+
self.r = self._maximal_r() * 0.9999
569+
self.r = self._RR(self.r)
565570

566-
Sigma2 = self._sigma - r**2 * self.Q
571+
Sigma2 = self._sigma - self.r**2 * self.Q
567572
try:
568-
self.r = r
569573
verbose(f"Computing Cholesky decomposition of a {Sigma2.dimensions()} matrix")
570574
self.B2 = Sigma2.cholesky().T
571575
self.B2_B_inv = self.B2 * self.B_inv
572576
except ValueError:
573577
raise ValueError("Σ₂ is not positive definite. Is your "\
574-
f"r(={r}) too large? It should be at most "\
578+
f"r(={self.r}) too large? It should be at most "\
575579
f"{self._maximal_r()}")
576580

577581

582+
578583
def __call__(self):
579584
r"""
580585
Return a new sample.
@@ -640,6 +645,25 @@ def sigma(self):
640645
"""
641646
return self._sigma
642647

648+
@sigma.setter
649+
def sigma(self, sigma):
650+
r"""
651+
Modifies center `σ`.
652+
653+
EXAMPLES::
654+
655+
sage: from sage.stats.all import DGL
656+
sage: D = DGL(ZZ^3, 3.0, c=(1,0,0))
657+
sage: D.c = (2, 0, 0)
658+
sage: D
659+
Discrete Gaussian sampler with Gaussian parameter σ = 3.00000000000000, c=(2, 0, 0) over lattice with basis
660+
<BLANKLINE>
661+
[1 0 0]
662+
[0 1 0]
663+
[0 0 1]
664+
"""
665+
self._sigma = sigma
666+
643667
@property
644668
def c(self):
645669
r"""
@@ -663,22 +687,43 @@ def c(self):
663687
return self._c
664688

665689
@c.setter
666-
def c(self, _):
690+
def c(self, c):
667691
r"""
668692
Modifies center `c`
669693
670694
EXAMPLES::
671695
672696
sage: from sage.stats.all import DGL
673697
sage: D = DGL(ZZ^3, 3.0, c=(1,0,0))
674-
sage: D.c = 5
675-
Traceback (most recent call last):
676-
...
677-
NotImplementedError: Modifying c is not yet supported!
698+
sage: D.c = (2, 0, 0)
699+
sage: D
700+
Discrete Gaussian sampler with Gaussian parameter σ = 3.00000000000000, c=(2, 0, 0) over lattice with basis
701+
<BLANKLINE>
702+
[1 0 0]
703+
[0 1 0]
704+
[0 0 1]
678705
"""
679-
# TODO: Isolate code to set `c` here, so that the offline part of
680-
# non-spherical sampling can be effectively utilised
681-
raise NotImplementedError("Modifying c is not yet supported!")
706+
if c is None:
707+
self._c = None
708+
return
709+
710+
if c == 0:
711+
c = vector(ZZ, self.n)
712+
else:
713+
try:
714+
c = vector(ZZ, self.n, c)
715+
except TypeError:
716+
try:
717+
c = vector(QQ, self.n, c)
718+
except TypeError:
719+
try:
720+
c = vector(self._RR, self.n, c)
721+
except TypeError:
722+
c = vector(self._RR, self.n)
723+
724+
self._c = c
725+
self._precompute_data()
726+
682727

683728
def __repr__(self):
684729
r"""
@@ -703,7 +748,6 @@ def __repr__(self):
703748
[0 1 0]
704749
[0 0 1]
705750
"""
706-
# beware of unicode character in ascii string !
707751
if self.is_spherical:
708752
sigma_str = f"σ = {self._sigma}"
709753
else:
@@ -749,7 +793,7 @@ def _call(self):
749793
Do not call this method directly, call :func:`DGL.__call__` instead.
750794
"""
751795
v = 0
752-
c, sigma, B = self._c, self._sigma, self.B
796+
c, sigma, B = self.c, self._sigma, self.B
753797

754798
m = self.B.nrows()
755799

@@ -805,7 +849,7 @@ def _call_non_spherical(self):
805849
"""
806850
if len(self.offline_samples) == 0:
807851
self.add_offline_samples()
808-
vec = self.c * self.B_inv - self.offline_samples.pop()
852+
vec = self._c_mul_B_inv - self.offline_samples.pop()
809853
return self._randomise(vec) * self.B
810854

811855

0 commit comments

Comments
 (0)