Skip to content

Commit 7e35966

Browse files
committed
Compute _normalisation_factor_zz for sigma > 0
The formula has been cross-checked with an external C program for a smaller lattice.
1 parent 49fdd15 commit 7e35966

File tree

1 file changed

+81
-35
lines changed

1 file changed

+81
-35
lines changed

src/sage/stats/distributions/discrete_gaussian_lattice.py

Lines changed: 81 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def _normalisation_factor_zz(self, tau=None, prec=None):
208208
sage: D = DiscreteGaussianDistributionLatticeSampler(ZZ^n, sigma)
209209
sage: f = D.f
210210
sage: c = D._normalisation_factor_zz(); c
211-
15.7496..
211+
15.7496...
212212
213213
sage: from collections import defaultdict
214214
sage: counter = defaultdict(Integer)
@@ -241,10 +241,19 @@ def _normalisation_factor_zz(self, tau=None, prec=None):
241241
1558545456544038969634991553
242242
243243
sage: M = Matrix(ZZ, [[1, 3, 0], [-2, 5, 1], [3, -4, 2]])
244-
sage: D = DiscreteGaussianDistributionLatticeSampler(M, 3)
244+
sage: D = DiscreteGaussianDistributionLatticeSampler(M, 1.7)
245+
sage: D._normalisation_factor_zz() # long time
246+
7247.1975...
245247
246248
sage: M = Matrix(ZZ, [[1, 3, 0], [-2, 5, 1]])
247249
sage: D = DiscreteGaussianDistributionLatticeSampler(M, 3)
250+
sage: D._normalisation_factor_zz()
251+
Traceback (most recent call last):
252+
...
253+
NotImplementedError: Basis must be a square matrix for now.
254+
255+
sage: c = vector([3, 7, 1])
256+
sage: D = DiscreteGaussianDistributionLatticeSampler(M, 3, c = c)
248257
"""
249258

250259
# If σ > 1:
@@ -261,39 +270,43 @@ def f(x):
261270
return R(exp(-pi**2 * (2 * sigma**2) * x))
262271
return R(exp(-x / (2 * sigma**2)))
263272

264-
if self.B != 1:
265-
# TODO: Implement
266-
raise NotImplementedError("Implement")
273+
if self.B.nrows() != self.B.ncols():
274+
raise NotImplementedError("Basis must be a square matrix for now.")
275+
276+
if not self._c_in_lattice:
277+
raise NotImplementedError("Lattice must contain 0 for now.")
278+
279+
n = self.B.nrows()
267280

268281
sigma = self._sigma
269282
prec = DiscreteGaussianDistributionLatticeSampler.compute_precision(
270283
prec, sigma
271284
)
272285
R = RealField(prec=prec)
273286
if sigma > 1:
274-
B = self.B
275-
# TODO: Take B dual
276-
raise NotImplementedError("oh no")
277-
norm_factor = (sigma * sqrt(2 * pi))**self.B.ncols()
287+
det = self.B.det()
288+
norm_factor = (sigma * sqrt(2 * pi))**n / det
278289
else:
279-
B = self.B
290+
det = 1
280291
norm_factor = 1
281292

282293
# qfrep computes theta series of a quadratic form, which is *half* the
283294
# generating function of number of vectors with given norm (and no 0)
295+
Q = self.B * self.B.T
284296
if tau is not None:
285-
freq = self.Q.__pari__().qfrep(tau * sigma, 0)
297+
freq = Q.__pari__().qfrep(tau * sigma, 0)
286298
res = R(1)
287299
for x, fq in enumerate(freq):
288-
res += 2 * ZZ(fq) * f(x + 1)
300+
res += 2 * ZZ(fq) * f((x + 1) / det**n)
289301
return R(norm_factor * res)
290302

291303
res = R(1)
292304
bound = 0
305+
# There might still be precision issue but whatever
293306
while True:
294307
bound += 1
295-
cnt = ZZ(self.Q.__pari__().qfrep(bound, 0)[bound - 1])
296-
inc = 2 * cnt * f(bound)
308+
cnt = ZZ(Q.__pari__().qfrep(bound, 0)[bound - 1])
309+
inc = 2 * cnt * f(bound / det**n)
297310
if cnt > 0 and res == res + inc:
298311
return R(norm_factor * res)
299312
res += inc
@@ -311,7 +324,12 @@ def __init__(self, B, sigma=1, c=None, precision=None):
311324
- an object with a ``matrix()`` method, e.g. ``ZZ^n``, or
312325
- an object where ``matrix(B)`` succeeds, e.g. a list of vectors.
313326
314-
- ``sigma`` -- Gaussian parameter `σ>0`.
327+
- ``sigma`` -- Gaussian parameter, one of the following:
328+
329+
- a real number `σ > 0`,
330+
- a positive definite matrix `Σ`, or
331+
- any matrix-like ``B``, equivalent to ``Σ = BBᵀ``
332+
315333
- ``c`` -- center `c`, any vector in `\ZZ^n` is supported, but `c ∈ Λ(B)` is faster.
316334
- ``precision`` -- bit precision `≥ 53`.
317335
@@ -322,7 +340,7 @@ def __init__(self, B, sigma=1, c=None, precision=None):
322340
sage: D = DiscreteGaussianDistributionLatticeSampler(ZZ^n, sigma)
323341
sage: f = D.f
324342
sage: c = D._normalisation_factor_zz(); c
325-
56.5486677646162
343+
56.5486677646...
326344
327345
sage: from collections import defaultdict
328346
sage: counter = defaultdict(Integer)
@@ -357,7 +375,17 @@ def __init__(self, B, sigma=1, c=None, precision=None):
357375
precision = DiscreteGaussianDistributionLatticeSampler.compute_precision(precision, sigma)
358376

359377
self._RR = RealField(precision)
360-
self._sigma = self._RR(sigma)
378+
# Check if sigma is a (real) number or a scaled identity matrix
379+
self.is_spherical = True
380+
try:
381+
self._sigma = self._RR(sigma)
382+
except TypeError as e:
383+
print("error:", e)
384+
self._sigma = matrix(self._RR, sigma)
385+
if self._sigma == self._sigma[0, 0]:
386+
self._sigma = self._RR(self._sigma[0, 0])
387+
else:
388+
self.is_spherical = False
361389

362390
try:
363391
B = matrix(B)
@@ -384,25 +412,32 @@ def __init__(self, B, sigma=1, c=None, precision=None):
384412

385413
self.f = lambda x: exp(-(vector(ZZ, B.ncols(), x) - c).norm() ** 2 / (2 * self._sigma ** 2))
386414

387-
# deal with trivial case first, it is common
388-
if self._G == 1 and self._c == 0:
389-
self._c_in_lattice = True
390-
D = DiscreteGaussianDistributionIntegerSampler(sigma=sigma)
391-
self.D = tuple([D for _ in range(self.B.nrows())])
392-
self.VS = FreeModule(ZZ, B.nrows())
393-
return
394-
395-
w = B.solve_left(c)
396-
if w in ZZ ** B.nrows():
397-
self._c_in_lattice = True
398-
D = []
399-
for i in range(self.B.nrows()):
400-
sigma_ = self._sigma / self._G[i].norm()
401-
D.append(DiscreteGaussianDistributionIntegerSampler(sigma=sigma_))
402-
self.D = tuple(D)
403-
self.VS = FreeModule(ZZ, B.nrows())
415+
if self.is_spherical:
416+
# deal with trivial case first, it is common
417+
if self._G == 1 and self._c == 0:
418+
self._c_in_lattice = True
419+
D = DiscreteGaussianDistributionIntegerSampler(sigma=sigma)
420+
self.D = tuple([D for _ in range(self.B.nrows())])
421+
self.VS = FreeModule(ZZ, B.nrows())
422+
return
423+
424+
else:
425+
self._c_in_lattice = False
426+
try:
427+
w = B.solve_left(c)
428+
if w in ZZ ** B.nrows():
429+
self._c_in_lattice = True
430+
D = []
431+
for i in range(self.B.nrows()):
432+
sigma_ = self._sigma / self._G[i].norm()
433+
D.append(DiscreteGaussianDistributionIntegerSampler(sigma=sigma_))
434+
self.D = tuple(D)
435+
self.VS = FreeModule(ZZ, B.nrows())
436+
except ValueError:
437+
pass
404438
else:
405-
self._c_in_lattice = False
439+
# TODO: Precompute basis of sqrt(sigma), change _str_
440+
raise NotImplementedError
406441

407442
def __call__(self):
408443
r"""
@@ -535,3 +570,14 @@ def _call(self):
535570
c = c - z * B[i]
536571
v = v + z * B[i]
537572
return v
573+
574+
def _call_non_spherical(self):
575+
"""
576+
Non-spherical sampler
577+
578+
.. note::
579+
580+
Do not call this method directly, call :func:`DiscreteGaussianDistributionLatticeSampler.__call__` instead.
581+
"""
582+
# TODO: Implement
583+
raise NotImplementedError

0 commit comments

Comments
 (0)