Skip to content

Commit 1a0c5e2

Browse files
author
Release Manager
committed
gh-38405: Call more general algorithm when lattice basis isn't trivial fixes #38400 - [x] I have linked a relevant issue or discussion. - [x] I have created tests covering the changes. - [x] I have updated the documentation and checked the documentation preview. URL: #38405 Reported by: Martin R. Albrecht Reviewer(s): Martin R. Albrecht, Matthias Köppe
2 parents 4083147 + 261730a commit 1a0c5e2

File tree

1 file changed

+31
-19
lines changed

1 file changed

+31
-19
lines changed

src/sage/stats/distributions/discrete_gaussian_lattice.py

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,9 @@ def _normalisation_factor_zz(self, tau=None, prec=None):
248248
sage: M = Matrix(ZZ, [[1, 3, 0], [-2, 5, 1], [3, -4, 2]])
249249
sage: D = DGL(M, 1.7)
250250
sage: D._normalisation_factor_zz() # long time
251-
7247.1975...
251+
Traceback (most recent call last):
252+
...
253+
NotImplementedError: center must be at zero and basis must be trivial
252254
253255
sage: Sigma = Matrix(ZZ, [[5, -2, 4], [-2, 10, -5], [4, -5, 5]])
254256
sage: D = DGL(ZZ^3, Sigma, [7, 2, 5])
@@ -260,19 +262,19 @@ def _normalisation_factor_zz(self, tau=None, prec=None):
260262
sage: D._normalisation_factor_zz()
261263
Traceback (most recent call last):
262264
...
263-
NotImplementedError: basis must be a square matrix for now
265+
NotImplementedError: basis must be a square matrix
264266
265267
sage: D = DGL(ZZ^3, c=(1/2, 0, 0))
266268
sage: D._normalisation_factor_zz()
267269
Traceback (most recent call last):
268270
...
269-
NotImplementedError: lattice must contain 0 for now
271+
NotImplementedError: center must be at zero and basis must be trivial
270272
271273
sage: D = DGL(Matrix(3, 3, 1/2))
272274
sage: D._normalisation_factor_zz()
273275
Traceback (most recent call last):
274276
...
275-
NotImplementedError: lattice must be integral for now
277+
NotImplementedError: lattice must be integral
276278
"""
277279
# If σ > 1:
278280
# We use the Fourier transform g(t) of f(x) = exp(-k^2 / 2σ^2), but
@@ -312,13 +314,13 @@ def f_or_hat(x):
312314
return sum(self.f((vector(u) + base) * self.B) for u in coords)
313315

314316
if self.B.nrows() != self.B.ncols():
315-
raise NotImplementedError("basis must be a square matrix for now")
316-
317-
if self.is_spherical and not self._c_in_lattice:
318-
raise NotImplementedError("lattice must contain 0 for now")
317+
raise NotImplementedError("basis must be a square matrix")
319318

320319
if self.B.base_ring() != ZZ:
321-
raise NotImplementedError("lattice must be integral for now")
320+
raise NotImplementedError("lattice must be integral")
321+
322+
if self.is_spherical and not self._c_in_lattice_and_lattice_trivial:
323+
raise NotImplementedError("center must be at zero and basis must be trivial")
322324

323325
sigma = self._sigma
324326
prec = DiscreteGaussianDistributionLatticeSampler.compute_precision(
@@ -583,7 +585,7 @@ def __init__(self, B, sigma=1, c=0, r=None, precision=None, sigma_basis=False):
583585
self.B = B
584586
self.Q = B * B.T
585587
self._G = B.gram_schmidt()[0]
586-
self._c_in_lattice = False
588+
self._c_in_lattice_and_lattice_trivial = False
587589

588590
self.D = None
589591
self.VS = None
@@ -612,18 +614,19 @@ def _precompute_data(self):
612614
Do not call this method directly, it is called automatically from
613615
:func:`DiscreteGaussianDistributionLatticeSampler.__init__`.
614616
"""
617+
615618
if self.is_spherical:
616619
# deal with trivial case first, it is common
617-
if self._G == 1 and self._c == 0:
618-
self._c_in_lattice = True
620+
if self._c == 0 and self._G == 1:
621+
self._c_in_lattice_and_lattice_trivial = True
619622
D = DiscreteGaussianDistributionIntegerSampler(sigma=self._sigma)
620623
self.D = tuple([D for _ in range(self.B.nrows())])
621624
self.VS = FreeModule(ZZ, self.B.nrows())
622625

623626
else:
624627
w = self.B.solve_left(self._c)
625-
if w in ZZ ** self.B.nrows():
626-
self._c_in_lattice = True
628+
if w in ZZ ** self.B.nrows() and self._G == 1:
629+
self._c_in_lattice_and_lattice_trivial = True
627630
D = []
628631
for i in range(self.B.nrows()):
629632
sigma_ = self._sigma / self._G[i].norm()
@@ -673,11 +676,20 @@ def __call__(self):
673676
sage: mean_L = sum(L) / len(L) # long time
674677
sage: norm(mean_L.n() - D.c()) < 0.25 # long time
675678
True
679+
680+
sage: import numpy
681+
sage: M = matrix(ZZ, [[1,2],[0,1]])
682+
sage: D = distributions.DiscreteGaussianDistributionLatticeSampler(M, 20.0)
683+
sage: L = [D() for _ in range(2^12)] # long time
684+
sage: div = numpy.mean([abs(x) for x,y in L]) / numpy.mean([abs(y) for x,y, in L]) # long time
685+
sage: 0.9 < div < 1.1 # long time
686+
True
687+
676688
"""
677689
if not self.is_spherical:
678690
v = self._call_non_spherical()
679-
elif self._c_in_lattice:
680-
v = self._call_in_lattice()
691+
elif self._c_in_lattice_and_lattice_trivial:
692+
v = self._call_simple()
681693
else:
682694
v = self._call()
683695
v.set_immutable()
@@ -807,14 +819,14 @@ def __repr__(self):
807819
sigma_str = f"Σ =\n{self._sigma}"
808820
return f"Discrete Gaussian sampler with Gaussian parameter {sigma_str}, c={self._c} over lattice with basis\n\n{self.B}"
809821

810-
def _call_in_lattice(self):
822+
def _call_simple(self):
811823
r"""
812-
Return a new sample assuming `c \in \Lambda(B)`.
824+
Return a new sample assuming `c \in \Lambda(B)` and `B^* = 1`.
813825
814826
EXAMPLES::
815827
816828
sage: D = distributions.DiscreteGaussianDistributionLatticeSampler(ZZ^3, 3.0, c=(1,0,0))
817-
sage: L = [D._call_in_lattice() for _ in range(2^12)]
829+
sage: L = [D._call_simple() for _ in range(2^12)]
818830
sage: mean_L = sum(L) / len(L)
819831
sage: norm(mean_L.n() - D.c()) < 0.25
820832
True

0 commit comments

Comments
 (0)