Skip to content

Commit f6d0bef

Browse files
Neural-Link Teamtensorflow-copybara
authored andcommitted
Makes the private all_graph_edges() function more memory efficient.
PiperOrigin-RevId: 294959528
1 parent c1fc418 commit f6d0bef

File tree

2 files changed

+15
-8
lines changed

2 files changed

+15
-8
lines changed

neural_structured_learning/tools/graph_utils.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,11 +89,13 @@ def add_undirected_edges(graph):
8989
`None`. Instead, this function has a side-effect on the `graph` argument.
9090
"""
9191
def all_graph_edges():
92-
edges = []
93-
for s, t_dict in six.iteritems(graph):
94-
for t, w in six.iteritems(t_dict):
95-
edges.append((s, t, w))
96-
return edges
92+
# Make a copy of all source IDs to avoid concurrent iteration failure.
93+
sources = list(graph.keys())
94+
for source in sources:
95+
# Make a copy of source's out-edges to avoid concurrent iteration failure.
96+
out_edges = dict(graph[source])
97+
for target, weight in six.iteritems(out_edges):
98+
yield (source, target, weight)
9799

98100
start_time = time.time()
99101
logging.info('Making all edges bi-directional...')

neural_structured_learning/tools/graph_utils_test.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from absl.testing import absltest
2323
from neural_structured_learning.tools import graph_utils
2424

25-
GRAPH = {'A': {'B': 0.5, 'C': 0.9}, 'B': {'A': 0.4, 'C': 1.0}}
25+
GRAPH = {'A': {'B': 0.5, 'C': 0.9}, 'B': {'A': 0.4, 'C': 1.0}, 'D': {'A': 0.75}}
2626

2727

2828
class GraphUtilsTest(absltest.TestCase):
@@ -35,6 +35,7 @@ def testAddEdge(self):
3535
graph_utils.add_edge(graph, ['A', 'C', 0.8]) # ...is used.
3636
graph_utils.add_edge(graph, ('B', 'A', '0.4'))
3737
graph_utils.add_edge(graph, ('B', 'C')) # Tests default weight
38+
graph_utils.add_edge(graph, ('D', 'A', 0.75))
3839
self.assertDictEqual(graph, GRAPH)
3940

4041
def testAddUndirectedEdges(self):
@@ -44,15 +45,19 @@ def testAddUndirectedEdges(self):
4445
g_actual, {
4546
'A': {
4647
'B': 0.5,
47-
'C': 0.9
48+
'C': 0.9,
49+
'D': 0.75
4850
},
4951
'B': {
5052
'A': 0.5, # Note, changed from 0.4 to 0.5
5153
'C': 1.0
5254
},
53-
'C': {
55+
'C': { # Added
5456
'A': 0.9, # Added
5557
'B': 1.0 # Added
58+
},
59+
'D': {
60+
'A': 0.75
5661
}
5762
})
5863

0 commit comments

Comments
 (0)