Skip to content
This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit d530562

Browse files
Mesh TensorFlow TeamCopybara-Service
authored andcommitted
Uniquify operation names: 'einsum', 'einsum_1', 'einsum_2', etc.
PiperOrigin-RevId: 219818155
1 parent 07ab904 commit d530562

File tree

2 files changed

+57
-4
lines changed

2 files changed

+57
-4
lines changed

mesh_tensorflow/ops.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,8 @@ def __init__(self):
350350
self._tensors = []
351351
self._trainable_variables = []
352352
self._all_variables = []
353+
# Maps a name used in the graph to the next id to use for that name.
354+
self._names_in_use = {}
353355

354356
def __repr__(self):
355357
return self.to_string
@@ -374,6 +376,37 @@ def all_variables(self):
374376
def to_string(self):
375377
return "\n".join([op.to_string for op in self.operations])
376378

379+
def unique_name(self, name, mark_as_used=True):
380+
"""Like tf.Graph.unique_name, returns a unique operation name for `name`.
381+
382+
Args:
383+
name: The name for an operation.
384+
mark_as_used: whether to mark this name as being used.
385+
386+
Returns:
387+
A string to use as the name for the operation.
388+
"""
389+
scope_name = tf.get_variable_scope().name
390+
if scope_name:
391+
name = scope_name + "/" + name
392+
393+
# As in TensorFlow, treat names as case insensitive when deciding whether
394+
# they are in use.
395+
name_key = name.lower()
396+
i = self._names_in_use.get(name_key, 0)
397+
if mark_as_used:
398+
self._names_in_use[name_key] = i + 1
399+
if i > 0:
400+
base_name_key = name_key
401+
while name_key in self._names_in_use:
402+
name_key = "%s_%d" % (base_name_key, i)
403+
i += 1
404+
if mark_as_used:
405+
self._names_in_use[name_key] = 1
406+
name = "%s_%d" % (name, i-1)
407+
408+
return name
409+
377410

378411
class Lowering(object):
379412
"""Lowering of a Graph from Mesh-TensorFlow to TensorFlow.
@@ -1143,10 +1176,7 @@ def __init__(self, inputs, mesh=None, name=None):
11431176
self._outputs = []
11441177
self._mesh = mesh
11451178
assert name is not None
1146-
scope_name = tf.get_variable_scope().name
1147-
if scope_name:
1148-
name = scope_name + "/" + name
1149-
self._name = name
1179+
self._name = mesh.graph.unique_name(name)
11501180
mesh.graph.operations.append(self)
11511181

11521182
@property

mesh_tensorflow/ops_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,29 @@ def testGraph(self):
116116
self.assertLen(graph.trainable_variables, 1)
117117
self.assertLen(graph.all_variables, 2)
118118

119+
def testGraphNames(self):
120+
# Standard Usage.
121+
graph = mtf.Graph()
122+
self.assertEqual(graph.unique_name("a"), "a")
123+
self.assertEqual(graph.unique_name("a"), "a_1")
124+
self.assertEqual(graph.unique_name("a"), "a_2")
125+
126+
# Edge cases, the user may choose the name "a_1".
127+
graph = mtf.Graph()
128+
self.assertEqual(graph.unique_name("a"), "a")
129+
self.assertEqual(graph.unique_name("a"), "a_1")
130+
self.assertEqual(graph.unique_name("a_1"), "a_1_1")
131+
132+
graph = mtf.Graph()
133+
self.assertEqual(graph.unique_name("a"), "a")
134+
self.assertEqual(graph.unique_name("a_1"), "a_1")
135+
self.assertEqual(graph.unique_name("a"), "a_2")
136+
137+
# Case insensitive.
138+
graph = mtf.Graph()
139+
self.assertEqual(graph.unique_name("a"), "a")
140+
self.assertEqual(graph.unique_name("A"), "A_1")
141+
119142
@tf.contrib.eager.run_test_in_graph_and_eager_modes()
120143
def testLowering(self):
121144
graph = mtf.Graph()

0 commit comments

Comments
 (0)