Skip to content

Commit 49fdd15

Browse files
committed
Compute normalisation const of lattice DGS faster
I rewrote the code for `_normalisation_factor_zz`. It used to enumerate short vectors, which might be imprecise even for a short basis. The new method uses `pari.qfrep` to enumerate vectors of a bounded norm + speeds up convergence of series using Poisson summation. TODO: The code doesn't work for non-integral lattices (i.e. over QQ), fixable by rescaling. The code is incorrect for non self-dual lattices, fixable by implementing the correct formula.
1 parent 26f5a09 commit 49fdd15

File tree

1 file changed

+79
-19
lines changed

1 file changed

+79
-19
lines changed

src/sage/stats/distributions/discrete_gaussian_lattice.py

Lines changed: 79 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -57,14 +57,15 @@
5757
#*****************************************************************************/
5858

5959
from sage.functions.log import exp
60-
from sage.functions.other import ceil
6160
from sage.rings.real_mpfr import RealField
6261
from sage.rings.real_mpfr import RR
6362
from sage.rings.integer_ring import ZZ
6463
from sage.rings.rational_field import QQ
6564
from .discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler
6665
from sage.structure.sage_object import SageObject
67-
from sage.matrix.constructor import matrix, identity_matrix
66+
from sage.misc.functional import sqrt
67+
from sage.symbolic.constants import pi
68+
from sage.matrix.constructor import matrix
6869
from sage.modules.free_module import FreeModule
6970
from sage.modules.free_module_element import vector
7071

@@ -156,7 +157,7 @@ def compute_precision(precision, sigma):
156157
157158
INPUT:
158159
159-
- ``precision`` - an integer `> 53` nor ``None``.
160+
- ``precision`` - an integer `>= 53` nor ``None``.
160161
- ``sigma`` - if ``precision`` is ``None`` then the precision of
161162
``sigma`` is used.
162163
@@ -185,19 +186,20 @@ def compute_precision(precision, sigma):
185186
precision = max(53, precision)
186187
return precision
187188

188-
def _normalisation_factor_zz(self, tau=3):
189+
def _normalisation_factor_zz(self, tau=None, prec=None):
189190
r"""
190-
This function returns an approximation of `∑_{x ∈ \ZZ^n}
191+
This function returns an approximation of `∑_{x ∈ B}
191192
\exp(-|x|_2^2/(2σ²))`, i.e. the normalisation factor such that the sum
192-
over all probabilities is 1 for `\ZZⁿ`.
193-
194-
If this ``self.B`` is not an identity matrix over `\ZZ` a
195-
``NotImplementedError`` is raised.
193+
over all probabilities is 1 for `B`, via Poisson summation.
196194
197195
INPUT:
198196
199-
- ``tau`` -- all vectors `v` with `|v|_∞ ≤ τ·σ` are enumerated
200-
(default: ``3``).
197+
- ``tau`` -- (default: ``None``) all vectors `v` with `|v|_2^2 ≤ τ·σ`
198+
are enumerated; if none is provided, enumerate vectors with
199+
increasing norm until the sum converges to given precision. For high
200+
dimension lattice, this is recommended.
201+
202+
- ``prec`` -- (default: ``None``) Passed to :meth:`compute_precision`
201203
202204
EXAMPLES::
203205
@@ -206,7 +208,7 @@ def _normalisation_factor_zz(self, tau=3):
206208
sage: D = DiscreteGaussianDistributionLatticeSampler(ZZ^n, sigma)
207209
sage: f = D.f
208210
sage: c = D._normalisation_factor_zz(); c
209-
15.528...
211+
15.7496..
210212
211213
sage: from collections import defaultdict
212214
sage: counter = defaultdict(Integer)
@@ -228,15 +230,73 @@ def _normalisation_factor_zz(self, tau=3):
228230
sage: while v not in counter: add_samples(1000)
229231
230232
sage: while abs(m*f(v)*1.0/c/counter[v] - 1.0) >= 0.2: add_samples(1000) # long time
233+
234+
sage: D = DiscreteGaussianDistributionLatticeSampler(ZZ^8, 0.5)
235+
sage: D._normalisation_factor_zz(tau=3)
236+
3.1653...
237+
sage: D._normalisation_factor_zz()
238+
6.8249...
239+
sage: D = DiscreteGaussianDistributionLatticeSampler(ZZ^8, 1000)
240+
sage: round(D._normalisation_factor_zz(prec=100))
241+
1558545456544038969634991553
242+
243+
sage: M = Matrix(ZZ, [[1, 3, 0], [-2, 5, 1], [3, -4, 2]])
244+
sage: D = DiscreteGaussianDistributionLatticeSampler(M, 3)
245+
246+
sage: M = Matrix(ZZ, [[1, 3, 0], [-2, 5, 1]])
247+
sage: D = DiscreteGaussianDistributionLatticeSampler(M, 3)
231248
"""
232-
if self.B != identity_matrix(ZZ, self.B.nrows()):
233-
raise NotImplementedError("This function is only implemented when B is an identity matrix.")
234249

235-
f = self.f
236-
n = self.B.ncols()
250+
# If σ > 1:
251+
# We use the Fourier transform g(t) of f(x) = exp(-k^2 / 2σ^2), but
252+
# taking the norm of vector t^2 as input, and with norm_factor factored.
253+
# If σ ≤ 1:
254+
# The formula in docstring converges quickly since it has -1 / σ^2 in
255+
# the exponent
256+
def f(x):
257+
# Fun fact: If you remove this R() and delay the call to return,
258+
# It might give an error due to precision error. For example,
259+
# RR(1 + 100 * exp(-5.0 * pi^2)) == 0
260+
if sigma > 1:
261+
return R(exp(-pi**2 * (2 * sigma**2) * x))
262+
return R(exp(-x / (2 * sigma**2)))
263+
264+
if self.B != 1:
265+
# TODO: Implement
266+
raise NotImplementedError("Implement")
267+
237268
sigma = self._sigma
238-
return sum(f(x) for x in _iter_vectors(n, -ceil(tau * sigma),
239-
ceil(tau * sigma)))
269+
prec = DiscreteGaussianDistributionLatticeSampler.compute_precision(
270+
prec, sigma
271+
)
272+
R = RealField(prec=prec)
273+
if sigma > 1:
274+
B = self.B
275+
# TODO: Take B dual
276+
raise NotImplementedError("oh no")
277+
norm_factor = (sigma * sqrt(2 * pi))**self.B.ncols()
278+
else:
279+
B = self.B
280+
norm_factor = 1
281+
282+
# qfrep computes theta series of a quadratic form, which is *half* the
283+
# generating function of number of vectors with given norm (and no 0)
284+
if tau is not None:
285+
freq = self.Q.__pari__().qfrep(tau * sigma, 0)
286+
res = R(1)
287+
for x, fq in enumerate(freq):
288+
res += 2 * ZZ(fq) * f(x + 1)
289+
return R(norm_factor * res)
290+
291+
res = R(1)
292+
bound = 0
293+
while True:
294+
bound += 1
295+
cnt = ZZ(self.Q.__pari__().qfrep(bound, 0)[bound - 1])
296+
inc = 2 * cnt * f(bound)
297+
if cnt > 0 and res == res + inc:
298+
return R(norm_factor * res)
299+
res += inc
240300

241301
def __init__(self, B, sigma=1, c=None, precision=None):
242302
r"""
@@ -262,7 +322,7 @@ def __init__(self, B, sigma=1, c=None, precision=None):
262322
sage: D = DiscreteGaussianDistributionLatticeSampler(ZZ^n, sigma)
263323
sage: f = D.f
264324
sage: c = D._normalisation_factor_zz(); c
265-
56.2162803067524
325+
56.5486677646162
266326
267327
sage: from collections import defaultdict
268328
sage: counter = defaultdict(Integer)

0 commit comments

Comments
 (0)