Skip to content

Commit db226c4

Browse files
authored
Merge pull request #22 from saibalmars/nk_fix
Networkit fix: update to support the nk's new neighbor iterator
2 parents 462e76a + 5fd3b79 commit db226c4

File tree

5 files changed

+46
-11
lines changed

5 files changed

+46
-11
lines changed

GraphRicciCurvature/OllivierRicci.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,11 @@ def _get_single_node_neighbors_distributions(node, direction="successors"):
7272
"""
7373
if _Gk.isDirected():
7474
if direction == "predecessors":
75-
neighbors = _Gk.inNeighbors(node)
75+
neighbors = list(_Gk.iterInNeighbors(node))
7676
else: # successors
77-
neighbors = _Gk.neighbors(node)
77+
neighbors = list(_Gk.iterNeighbors(node))
7878
else:
79-
neighbors = _Gk.neighbors(node)
79+
neighbors = list(_Gk.iterNeighbors(node))
8080

8181
# Get sum of distributions from x's all neighbors
8282
heap_weight_node_pair = []
@@ -93,10 +93,11 @@ def _get_single_node_neighbors_distributions(node, direction="successors"):
9393

9494
nbr_edge_weight_sum = sum([x[0] for x in heap_weight_node_pair])
9595

96-
if len(neighbors) == 0:
96+
if not neighbors:
9797
# No neighbor, all mass stay at node
9898
return [1], [node]
99-
elif nbr_edge_weight_sum > EPSILON:
99+
100+
if nbr_edge_weight_sum > EPSILON:
100101
# Sum need to be not too small to prevent divided by zero
101102
distributions = [(1.0 - _alpha) * w / nbr_edge_weight_sum for w, _ in heap_weight_node_pair]
102103
else:
@@ -283,10 +284,10 @@ def _average_transportation_distance(source, target):
283284

284285
t0 = time.time()
285286
if _Gk.isDirected():
286-
source_nbr = _Gk.inNeighbors(source)
287+
source_nbr = list(_Gk.iterInNeighbors(source))
287288
else:
288-
source_nbr = _Gk.neighbors(source)
289-
target_nbr = _Gk.neighbors(target)
289+
source_nbr = list(_Gk.iterNeighbors(source))
290+
target_nbr = list(_Gk.iterNeighbors(target))
290291

291292
share = (1.0 - _alpha) / (len(source_nbr) * len(target_nbr))
292293
cost_nbr = 0

GraphRicciCurvature/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
__version__ = "0.5.2"
1+
__version__ = "0.5.2.1"
22
__author__ = "Chien-Chun Ni"
33
__email__ = "saibalmars@gmail.com"

example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
# -----------------------------------------------
2121
print("\n-Construct a directed graph example")
2222
Gd = nx.DiGraph()
23-
Gd.add_edges_from([(1, 2), (2, 3), (3, 4), (2, 4), (4, 2)])
23+
Gd.add_edges_from([(0, 1), (1, 2), (2, 3), (1, 3), (3, 1)])
2424

2525
print("\n===== Compute the Ollivier-Ricci curvature of the given directed graph Gd =====")
2626
orc_directed = OllivierRicci(Gd)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
setuptools.setup(
1010
name="GraphRicciCurvature",
11-
version="0.5.2",
11+
version="0.5.2.1",
1212
author="Chien-Chun Ni",
1313
author_email="saibalmars@gmail.com",
1414
description="Compute discrete Ricci curvatures and Ricci flow on NetworkX graphs.",

test/test_OllivierRicci.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,40 @@ def test_compute_ricci_curvature():
2929
npt.assert_array_almost_equal(rc, ans)
3030

3131

32+
def test_compute_ricci_curvature_directed():
33+
Gd = nx.DiGraph()
34+
Gd.add_edges_from([(0, 1), (1, 2), (2, 3), (1, 3), (3, 1)])
35+
orc = OllivierRicci(Gd, method="OTD", alpha=0.5)
36+
Gout = orc.compute_ricci_curvature()
37+
rc = list(nx.get_edge_attributes(Gout, "ricciCurvature").values())
38+
ans = [-0.49999999999999956,
39+
-3.842615114990622e-11,
40+
0.49999999996158007,
41+
0.49999999992677135,
42+
0.7499999999364129]
43+
44+
npt.assert_array_almost_equal(rc, ans)
45+
46+
47+
def test_compute_ricci_curvature_ATD():
48+
G = nx.karate_club_graph()
49+
orc = OllivierRicci(G, alpha=0.5, method="ATD", verbose="INFO")
50+
orc.compute_ricci_curvature()
51+
Gout = orc.compute_ricci_curvature()
52+
rc = list(nx.get_edge_attributes(Gout, "ricciCurvature").values())
53+
ans = [-0.343750, -0.437500, -0.265625, -0.250000, -0.390625, -0.390625, -0.195312, -0.443750, -0.250000,
54+
0.000000, -0.140625, -0.287500, -0.109375, -0.291667, -0.109375, -0.640625, -0.311111, -0.175926,
55+
-0.083333, -0.166667, 0.000000, -0.166667, 0.000000, -0.333333, -0.241667, -0.137500, -0.220000,
56+
-0.125000, -0.160000, -0.400000, -0.200000, -0.479167, 0.020833, 0.041667, -0.100000, -0.041667,
57+
0.055556, -0.062500, -0.041667, 0.000000, 0.000000, -0.075000, -0.275000, -0.300000, -0.176471,
58+
-0.464706, 0.000000, -0.073529, 0.000000, -0.073529, 0.000000, -0.073529, -0.421569, 0.000000,
59+
-0.073529, 0.000000, -0.073529, -0.200000, -0.200000, -0.125000, -0.291667, -0.335294, -0.055556,
60+
-0.208333, -0.194444, -0.194444, 0.062500, -0.176471, -0.375000, -0.166667, -0.245098, -0.197917,
61+
-0.227941, -0.250000, -0.294118, -0.430556, -0.455882, -0.355392]
62+
63+
npt.assert_array_almost_equal(rc, ans)
64+
65+
3266
def test_compute_ricci_flow():
3367
G = nx.karate_club_graph()
3468
orc = OllivierRicci(G, method="OTD", alpha=0.5)

0 commit comments

Comments
 (0)