Skip to content

Commit 44a0d2d

Browse files
author
Release Manager
committed
gh-36830: improved integer vectors efficiency -Enhancement Fixes #36816 1) **Added** the **cardinality function** using stars and bars to the **IntegerVector_nk** class 2) Fixed the errors in IntegerVectors_n and IntegerVectors_k now **IntegerVectors_n(0)cardinality() and IntegerVectors_k(0).cardinality() returns 1 and both of them are Finite EnumeratedSets** URL: #36830 Reported by: Aman Moon Reviewer(s): Aman Moon, Jukka Kohonen, Martin Rubey
2 parents 5d0093c + dedee8d commit 44a0d2d

File tree

1 file changed

+232
-26
lines changed

1 file changed

+232
-26
lines changed

src/sage/combinat/integer_vector.py

Lines changed: 232 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
from sage.rings.integer import Integer
4949

5050

51-
def is_gale_ryser(r,s):
51+
def is_gale_ryser(r, s):
5252
r"""
5353
Tests whether the given sequences satisfy the condition
5454
of the Gale-Ryser theorem.
@@ -314,20 +314,20 @@ def gale_ryser_theorem(p1, p2, algorithm="gale",
314314
"""
315315
from sage.matrix.constructor import matrix
316316

317-
if not is_gale_ryser(p1,p2):
317+
if not is_gale_ryser(p1, p2):
318318
return False
319319

320-
if algorithm == "ryser": # ryser's algorithm
320+
if algorithm == "ryser": # ryser's algorithm
321321
from sage.combinat.permutation import Permutation
322322

323323
# Sorts the sequences if they are not, and remembers the permutation
324324
# applied
325-
tmp = sorted(enumerate(p1), reverse=True, key=lambda x:x[1])
325+
tmp = sorted(enumerate(p1), reverse=True, key=lambda x: x[1])
326326
r = [x[1] for x in tmp]
327327
r_permutation = [x-1 for x in Permutation([x[0]+1 for x in tmp]).inverse()]
328328
m = len(r)
329329

330-
tmp = sorted(enumerate(p2), reverse=True, key=lambda x:x[1])
330+
tmp = sorted(enumerate(p2), reverse=True, key=lambda x: x[1])
331331
s = [x[1] for x in tmp]
332332
s_permutation = [x-1 for x in Permutation([x[0]+1 for x in tmp]).inverse()]
333333

@@ -340,12 +340,12 @@ def gale_ryser_theorem(p1, p2, algorithm="gale",
340340
k = i + 1
341341
while k < m and r[i] == r[k]:
342342
k += 1
343-
if t >= k - i: # == number rows of the same length
343+
if t >= k - i: # == number rows of the same length
344344
for j in range(i, k):
345345
r[j] -= 1
346346
c[j] = 1
347347
t -= k - i
348-
else: # Remove the t last rows of that length
348+
else: # Remove the t last rows of that length
349349
for j in range(k-t, k):
350350
r[j] -= 1
351351
c[j] = 1
@@ -366,17 +366,17 @@ def gale_ryser_theorem(p1, p2, algorithm="gale",
366366
k1, k2 = len(p1), len(p2)
367367
p = MixedIntegerLinearProgram(solver=solver)
368368
b = p.new_variable(binary=True)
369-
for (i,c) in enumerate(p1):
370-
p.add_constraint(p.sum([b[i,j] for j in range(k2)]) == c)
371-
for (i,c) in enumerate(p2):
372-
p.add_constraint(p.sum([b[j,i] for j in range(k1)]) == c)
369+
for (i, c) in enumerate(p1):
370+
p.add_constraint(p.sum([b[i, j] for j in range(k2)]) == c)
371+
for (i, c) in enumerate(p2):
372+
p.add_constraint(p.sum([b[j, i] for j in range(k1)]) == c)
373373
p.set_objective(None)
374374
p.solve()
375375
b = p.get_values(b, convert=ZZ, tolerance=integrality_tolerance)
376376
M = [[0]*k2 for i in range(k1)]
377377
for i in range(k1):
378378
for j in range(k2):
379-
M[i][j] = b[i,j]
379+
M[i][j] = b[i, j]
380380
return matrix(M)
381381

382382
else:
@@ -780,6 +780,43 @@ def __contains__(self, x):
780780
return False
781781
return True
782782

783+
def _unrank_helper(self, x, rtn):
784+
"""
785+
Return the element at rank ``x`` by iterating through all integer vectors beginning with ``rtn``.
786+
787+
INPUT:
788+
789+
- ``x`` - a nonnegative integer
790+
- ``rtn`` - a list of nonnegative integers
791+
792+
793+
EXAMPLES::
794+
795+
sage: IV = IntegerVectors(k=5)
796+
sage: IV._unrank_helper(10, [2,0,0,0,0])
797+
[1, 0, 0, 0, 1]
798+
799+
sage: IV = IntegerVectors(n=7)
800+
sage: IV._unrank_helper(100, [7,0,0,0])
801+
[2, 0, 0, 5]
802+
803+
sage: IV = IntegerVectors(n=12, k=7)
804+
sage: IV._unrank_helper(1000, [12,0,0,0,0,0,0])
805+
[5, 3, 1, 1, 1, 1, 0]
806+
"""
807+
ptr = 0
808+
while True:
809+
current_rank = self.rank(rtn)
810+
if current_rank < x:
811+
rtn[ptr+1] = rtn[ptr]
812+
rtn[ptr] = 0
813+
ptr += 1
814+
elif current_rank > x:
815+
rtn[ptr] -= 1
816+
rtn[ptr-1] += 1
817+
else:
818+
return self._element_constructor_(rtn)
819+
783820

784821
class IntegerVectors_all(UniqueRepresentation, IntegerVectors):
785822
"""
@@ -839,7 +876,10 @@ def __init__(self, n):
839876
sage: TestSuite(IV).run()
840877
"""
841878
self.n = n
842-
IntegerVectors.__init__(self, category=InfiniteEnumeratedSets())
879+
if self.n == 0:
880+
IntegerVectors.__init__(self, category=EnumeratedSets())
881+
else:
882+
IntegerVectors.__init__(self, category=InfiniteEnumeratedSets())
843883

844884
def _repr_(self):
845885
"""
@@ -898,6 +938,68 @@ def __contains__(self, x):
898938
return False
899939
return sum(x) == self.n
900940

941+
def rank(self, x):
942+
"""
943+
Return the rank of a given element.
944+
945+
INPUT:
946+
947+
- ``x`` -- a list with ``sum(x) == n``
948+
949+
EXAMPLES::
950+
951+
sage: IntegerVectors(n=5).rank([5,0])
952+
1
953+
sage: IntegerVectors(n=5).rank([3,2])
954+
3
955+
"""
956+
if sum(x) != self.n:
957+
raise ValueError("argument is not a member of IntegerVectors({},{})".format(self.n, None))
958+
959+
n, k, s = self.n, len(x), 0
960+
r = binomial(k + n - 1, n + 1)
961+
for i in range(k - 1):
962+
s += x[k - 1 - i]
963+
r += binomial(s + i, i + 1)
964+
return r
965+
966+
def unrank(self, x):
967+
"""
968+
Return the element at given rank x.
969+
970+
INPUT:
971+
972+
- ``x`` -- an integer.
973+
974+
EXAMPLES::
975+
976+
sage: IntegerVectors(n=5).unrank(2)
977+
[4, 1]
978+
sage: IntegerVectors(n=10).unrank(10)
979+
[1, 9]
980+
"""
981+
rtn = [self.n]
982+
while self.rank(rtn) <= x:
983+
rtn.append(0)
984+
rtn.pop()
985+
986+
return IntegerVectors._unrank_helper(self, x, rtn)
987+
988+
def cardinality(self):
989+
"""
990+
Return the cardinality of ``self``.
991+
992+
EXAMPLES::
993+
994+
sage: IntegerVectors(n=0).cardinality()
995+
1
996+
sage: IntegerVectors(n=10).cardinality()
997+
+Infinity
998+
"""
999+
if self.n == 0:
1000+
return Integer(1)
1001+
return PlusInfinity()
1002+
9011003

9021004
class IntegerVectors_k(UniqueRepresentation, IntegerVectors):
9031005
"""
@@ -912,7 +1014,10 @@ def __init__(self, k):
9121014
sage: TestSuite(IV).run()
9131015
"""
9141016
self.k = k
915-
IntegerVectors.__init__(self, category=InfiniteEnumeratedSets())
1017+
if self.k == 0:
1018+
IntegerVectors.__init__(self, category=EnumeratedSets())
1019+
else:
1020+
IntegerVectors.__init__(self, category=InfiniteEnumeratedSets())
9161021

9171022
def _repr_(self):
9181023
"""
@@ -968,6 +1073,75 @@ def __contains__(self, x):
9681073
return False
9691074
return len(x) == self.k
9701075

1076+
def rank(self, x):
1077+
"""
1078+
Return the rank of a given element.
1079+
1080+
INPUT:
1081+
1082+
- ``x`` -- a list with ``len(x) == k``
1083+
1084+
EXAMPLES::
1085+
1086+
sage: IntegerVectors(k=5).rank([0,0,0,0,0])
1087+
0
1088+
sage: IntegerVectors(k=5).rank([1,1,0,0,0])
1089+
7
1090+
"""
1091+
if len(x) != self.k:
1092+
raise ValueError("argument is not a member of IntegerVectors({},{})".format(None, self.k))
1093+
1094+
n, k, s = sum(x), self.k, 0
1095+
r = binomial(n + k - 1, k)
1096+
for i in range(k - 1):
1097+
s += x[k - 1 - i]
1098+
r += binomial(s + i, i + 1)
1099+
return r
1100+
1101+
def unrank(self, x):
1102+
"""
1103+
Return the element at given rank x.
1104+
1105+
INPUT:
1106+
1107+
- ``x`` -- an integer such that x < self.cardinality()``
1108+
1109+
EXAMPLES::
1110+
1111+
sage: IntegerVectors(k=5).unrank(10)
1112+
[1, 0, 0, 0, 1]
1113+
sage: IntegerVectors(k=5).unrank(15)
1114+
[0, 0, 2, 0, 0]
1115+
sage: IntegerVectors(k=0).unrank(0)
1116+
[]
1117+
"""
1118+
if self.k == 0 and x != 0:
1119+
raise IndexError(f"Index {x} is out of range for the IntegerVector.")
1120+
rtn = [0]*self.k
1121+
if self.k == 0 and x == 0:
1122+
return rtn
1123+
1124+
while self.rank(rtn) <= x:
1125+
rtn[0] += 1
1126+
rtn[0] -= 1
1127+
1128+
return IntegerVectors._unrank_helper(self, x, rtn)
1129+
1130+
def cardinality(self):
1131+
"""
1132+
Return the cardinality of ``self``.
1133+
1134+
EXAMPLES::
1135+
1136+
sage: IntegerVectors(k=0).cardinality()
1137+
1
1138+
sage: IntegerVectors(k=10).cardinality()
1139+
+Infinity
1140+
"""
1141+
if self.k == 0:
1142+
return Integer(1)
1143+
return PlusInfinity()
1144+
9711145

9721146
class IntegerVectors_nk(UniqueRepresentation, IntegerVectors):
9731147
"""
@@ -1010,11 +1184,11 @@ def _list_rec(self, n, k):
10101184
res = []
10111185

10121186
if k == 1:
1013-
return [ (n, ) ]
1187+
return [(n, )]
10141188

10151189
for nbar in range(n + 1):
10161190
n_diff = n - nbar
1017-
for rest in self._list_rec( nbar , k - 1):
1191+
for rest in self._list_rec(nbar, k - 1):
10181192
res.append((n_diff,) + rest)
10191193
return res
10201194

@@ -1153,17 +1327,49 @@ def rank(self, x):
11531327
if x not in self:
11541328
raise ValueError("argument is not a member of IntegerVectors({},{})".format(self.n, self.k))
11551329

1156-
n = self.n
1157-
k = self.k
1158-
1159-
r = 0
1330+
k, s, r = self.k, 0, 0
11601331
for i in range(k - 1):
1161-
k -= 1
1162-
n -= x[i]
1163-
r += binomial(k + n - 1, k)
1164-
1332+
s += x[k - 1 - i]
1333+
r += binomial(s + i, i + 1)
11651334
return r
11661335

1336+
def unrank(self, x):
1337+
"""
1338+
Return the element at given rank x.
1339+
1340+
INPUT:
1341+
1342+
- ``x`` -- an integer such that ``x < self.cardinality()``
1343+
1344+
EXAMPLES::
1345+
1346+
sage: IntegerVectors(4,5).unrank(30)
1347+
[1, 0, 1, 0, 2]
1348+
sage: IntegerVectors(2,3).unrank(5)
1349+
[0, 0, 2]
1350+
"""
1351+
if x >= self.cardinality():
1352+
raise IndexError(f"Index {x} is out of range for the IntegerVector.")
1353+
rtn = [0]*self.k
1354+
rtn[0] = self.n
1355+
return IntegerVectors._unrank_helper(self, x, rtn)
1356+
1357+
def cardinality(self):
1358+
"""
1359+
Return the cardinality of ``self``.
1360+
1361+
EXAMPLES::
1362+
1363+
sage: IntegerVectors(3,5).cardinality()
1364+
35
1365+
sage: IntegerVectors(99, 3).cardinality()
1366+
5050
1367+
sage: IntegerVectors(10^9 - 1, 3).cardinality()
1368+
500000000500000000
1369+
"""
1370+
n, k = self.n, self.k
1371+
return Integer(binomial(n + k - 1, n))
1372+
11671373

11681374
class IntegerVectors_nnondescents(UniqueRepresentation, IntegerVectors):
11691375
r"""
@@ -1320,11 +1526,11 @@ def __init__(self, n=None, k=None, **constraints):
13201526
category = FiniteEnumeratedSets()
13211527
else:
13221528
category = EnumeratedSets()
1323-
elif k is not None and 'max_part' in constraints: # n is None
1529+
elif k is not None and 'max_part' in constraints: # n is None
13241530
category = FiniteEnumeratedSets()
13251531
else:
13261532
category = EnumeratedSets()
1327-
IntegerVectors.__init__(self, category=category) # placeholder category
1533+
IntegerVectors.__init__(self, category=category) # placeholder category
13281534

13291535
def _repr_(self):
13301536
"""

0 commit comments

Comments
 (0)