@@ -208,7 +208,7 @@ def _normalisation_factor_zz(self, tau=None, prec=None):
208
208
sage: D = DiscreteGaussianDistributionLatticeSampler(ZZ^n, sigma)
209
209
sage: f = D.f
210
210
sage: c = D._normalisation_factor_zz(); c
211
- 15.7496..
211
+ 15.7496...
212
212
213
213
sage: from collections import defaultdict
214
214
sage: counter = defaultdict(Integer)
@@ -241,10 +241,19 @@ def _normalisation_factor_zz(self, tau=None, prec=None):
241
241
1558545456544038969634991553
242
242
243
243
sage: M = Matrix(ZZ, [[1, 3, 0], [-2, 5, 1], [3, -4, 2]])
244
- sage: D = DiscreteGaussianDistributionLatticeSampler(M, 3)
244
+ sage: D = DiscreteGaussianDistributionLatticeSampler(M, 1.7)
245
+ sage: D._normalisation_factor_zz() # long time
246
+ 7247.1975...
245
247
246
248
sage: M = Matrix(ZZ, [[1, 3, 0], [-2, 5, 1]])
247
249
sage: D = DiscreteGaussianDistributionLatticeSampler(M, 3)
250
+ sage: D._normalisation_factor_zz()
251
+ Traceback (most recent call last):
252
+ ...
253
+ NotImplementedError: Basis must be a square matrix for now.
254
+
255
+ sage: c = vector([3, 7, 1])
256
+ sage: D = DiscreteGaussianDistributionLatticeSampler(M, 3, c = c)
248
257
"""
249
258
250
259
# If σ > 1:
@@ -261,39 +270,43 @@ def f(x):
261
270
return R (exp (- pi ** 2 * (2 * sigma ** 2 ) * x ))
262
271
return R (exp (- x / (2 * sigma ** 2 )))
263
272
264
- if self .B != 1 :
265
- # TODO: Implement
266
- raise NotImplementedError ("Implement" )
273
+ if self .B .nrows () != self .B .ncols ():
274
+ raise NotImplementedError ("Basis must be a square matrix for now." )
275
+
276
+ if not self ._c_in_lattice :
277
+ raise NotImplementedError ("Lattice must contain 0 for now." )
278
+
279
+ n = self .B .nrows ()
267
280
268
281
sigma = self ._sigma
269
282
prec = DiscreteGaussianDistributionLatticeSampler .compute_precision (
270
283
prec , sigma
271
284
)
272
285
R = RealField (prec = prec )
273
286
if sigma > 1 :
274
- B = self .B
275
- # TODO: Take B dual
276
- raise NotImplementedError ("oh no" )
277
- norm_factor = (sigma * sqrt (2 * pi ))** self .B .ncols ()
287
+ det = self .B .det ()
288
+ norm_factor = (sigma * sqrt (2 * pi ))** n / det
278
289
else :
279
- B = self . B
290
+ det = 1
280
291
norm_factor = 1
281
292
282
293
# qfrep computes theta series of a quadratic form, which is *half* the
283
294
# generating function of number of vectors with given norm (and no 0)
295
+ Q = self .B * self .B .T
284
296
if tau is not None :
285
- freq = self . Q .__pari__ ().qfrep (tau * sigma , 0 )
297
+ freq = Q .__pari__ ().qfrep (tau * sigma , 0 )
286
298
res = R (1 )
287
299
for x , fq in enumerate (freq ):
288
- res += 2 * ZZ (fq ) * f (x + 1 )
300
+ res += 2 * ZZ (fq ) * f (( x + 1 ) / det ** n )
289
301
return R (norm_factor * res )
290
302
291
303
res = R (1 )
292
304
bound = 0
305
+ # There might still be precision issue but whatever
293
306
while True :
294
307
bound += 1
295
- cnt = ZZ (self . Q .__pari__ ().qfrep (bound , 0 )[bound - 1 ])
296
- inc = 2 * cnt * f (bound )
308
+ cnt = ZZ (Q .__pari__ ().qfrep (bound , 0 )[bound - 1 ])
309
+ inc = 2 * cnt * f (bound / det ** n )
297
310
if cnt > 0 and res == res + inc :
298
311
return R (norm_factor * res )
299
312
res += inc
@@ -311,7 +324,12 @@ def __init__(self, B, sigma=1, c=None, precision=None):
311
324
- an object with a ``matrix()`` method, e.g. ``ZZ^n``, or
312
325
- an object where ``matrix(B)`` succeeds, e.g. a list of vectors.
313
326
314
- - ``sigma`` -- Gaussian parameter `σ>0`.
327
+ - ``sigma`` -- Gaussian parameter, one of the following:
328
+
329
+ - a real number `σ > 0`,
330
+ - a positive definite matrix `Σ`, or
331
+ - any matrix-like ``B``, equivalent to ``Σ = BBᵀ``
332
+
315
333
- ``c`` -- center `c`, any vector in `\ZZ^n` is supported, but `c ∈ Λ(B)` is faster.
316
334
- ``precision`` -- bit precision `≥ 53`.
317
335
@@ -322,7 +340,7 @@ def __init__(self, B, sigma=1, c=None, precision=None):
322
340
sage: D = DiscreteGaussianDistributionLatticeSampler(ZZ^n, sigma)
323
341
sage: f = D.f
324
342
sage: c = D._normalisation_factor_zz(); c
325
- 56.5486677646162
343
+ 56.5486677646...
326
344
327
345
sage: from collections import defaultdict
328
346
sage: counter = defaultdict(Integer)
@@ -357,7 +375,17 @@ def __init__(self, B, sigma=1, c=None, precision=None):
357
375
precision = DiscreteGaussianDistributionLatticeSampler .compute_precision (precision , sigma )
358
376
359
377
self ._RR = RealField (precision )
360
- self ._sigma = self ._RR (sigma )
378
+ # Check if sigma is a (real) number or a scaled identity matrix
379
+ self .is_spherical = True
380
+ try :
381
+ self ._sigma = self ._RR (sigma )
382
+ except TypeError as e :
383
+ print ("error:" , e )
384
+ self ._sigma = matrix (self ._RR , sigma )
385
+ if self ._sigma == self ._sigma [0 , 0 ]:
386
+ self ._sigma = self ._RR (self ._sigma [0 , 0 ])
387
+ else :
388
+ self .is_spherical = False
361
389
362
390
try :
363
391
B = matrix (B )
@@ -384,25 +412,32 @@ def __init__(self, B, sigma=1, c=None, precision=None):
384
412
385
413
self .f = lambda x : exp (- (vector (ZZ , B .ncols (), x ) - c ).norm () ** 2 / (2 * self ._sigma ** 2 ))
386
414
387
- # deal with trivial case first, it is common
388
- if self ._G == 1 and self ._c == 0 :
389
- self ._c_in_lattice = True
390
- D = DiscreteGaussianDistributionIntegerSampler (sigma = sigma )
391
- self .D = tuple ([D for _ in range (self .B .nrows ())])
392
- self .VS = FreeModule (ZZ , B .nrows ())
393
- return
394
-
395
- w = B .solve_left (c )
396
- if w in ZZ ** B .nrows ():
397
- self ._c_in_lattice = True
398
- D = []
399
- for i in range (self .B .nrows ()):
400
- sigma_ = self ._sigma / self ._G [i ].norm ()
401
- D .append (DiscreteGaussianDistributionIntegerSampler (sigma = sigma_ ))
402
- self .D = tuple (D )
403
- self .VS = FreeModule (ZZ , B .nrows ())
415
+ if self .is_spherical :
416
+ # deal with trivial case first, it is common
417
+ if self ._G == 1 and self ._c == 0 :
418
+ self ._c_in_lattice = True
419
+ D = DiscreteGaussianDistributionIntegerSampler (sigma = sigma )
420
+ self .D = tuple ([D for _ in range (self .B .nrows ())])
421
+ self .VS = FreeModule (ZZ , B .nrows ())
422
+ return
423
+
424
+ else :
425
+ self ._c_in_lattice = False
426
+ try :
427
+ w = B .solve_left (c )
428
+ if w in ZZ ** B .nrows ():
429
+ self ._c_in_lattice = True
430
+ D = []
431
+ for i in range (self .B .nrows ()):
432
+ sigma_ = self ._sigma / self ._G [i ].norm ()
433
+ D .append (DiscreteGaussianDistributionIntegerSampler (sigma = sigma_ ))
434
+ self .D = tuple (D )
435
+ self .VS = FreeModule (ZZ , B .nrows ())
436
+ except ValueError :
437
+ pass
404
438
else :
405
- self ._c_in_lattice = False
439
+ # TODO: Precompute basis of sqrt(sigma), change _str_
440
+ raise NotImplementedError
406
441
407
442
def __call__ (self ):
408
443
r"""
@@ -535,3 +570,14 @@ def _call(self):
535
570
c = c - z * B [i ]
536
571
v = v + z * B [i ]
537
572
return v
573
+
574
+ def _call_non_spherical (self ):
575
+ """
576
+ Non-spherical sampler
577
+
578
+ .. note::
579
+
580
+ Do not call this method directly, call :func:`DiscreteGaussianDistributionLatticeSampler.__call__` instead.
581
+ """
582
+ # TODO: Implement
583
+ raise NotImplementedError
0 commit comments