Skip to content

Commit 522a0be

Browse files
katsiapis authored and tfx-copybara committed
Simplifying deep_copy.py (since TFT depends on sufficiently modern versions of Apache Beam), and fixing a potential variable collision (and silent overriding) for 'tag'.
PiperOrigin-RevId: 447036754
1 parent 2ac89ab commit 522a0be

File tree

1 file changed

+17
-23
lines changed

1 file changed

+17
-23
lines changed

tensorflow_transform/beam/deep_copy.py

Lines changed: 17 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -178,41 +178,35 @@ def _clone_items(pipeline, to_clone):
178178
# merged due to common subexpression elimination (CSE).
179179
item.resource_hints['beam:resources:tags:v1'] = b'DeepCopy.Original'
180180

181-
# Assign new label and resource tag.
182-
next_suffix = 0
183-
suffix = f'Copy[{next_suffix}]'
184-
new_label = item.full_label + f'.{suffix}'
185-
tag = f'DeepCopy.{suffix}'
181+
# Assign new label.
182+
count = 0
183+
copy_suffix = f'Copy[{count}]'
184+
new_label = f'{item.full_label}.{copy_suffix}'
186185
while new_label in pipeline.applied_labels:
187-
suffix = f'Copy[{next_suffix}]'
188-
new_label = item.full_label + f'.{suffix}'
189-
tag = f'DeepCopy.{suffix}'
190-
next_suffix += 1
186+
count += 1
187+
copy_suffix = f'Copy[{count}]'
188+
new_label = f'{item.full_label}.{copy_suffix}'
191189
pipeline.applied_labels.add(new_label)
192190

193191
# Update inputs.
194-
if hasattr(item, 'main_inputs'):
195-
new_inputs = {
196-
tag: pcollection_replacements.get(old_input, old_input)
197-
for tag, old_input in item.main_inputs.items()
198-
}
199-
else:
200-
new_inputs = tuple(
201-
pcollection_replacements.get(old_input, old_input)
202-
for old_input in item.inputs)
192+
new_inputs = {
193+
tag: pcollection_replacements.get(old_input, old_input)
194+
for tag, old_input in item.main_inputs.items()
195+
}
203196

204197
# Create the copy. Note that in the copy, copied.outputs will start out
205198
# empty. Any outputs that are used will be repopulated in the PCollection
206199
# copy branch above.
207200
copied = beam_pipeline.AppliedPTransform(item.parent, item.transform,
208201
new_label, new_inputs)
209202

210-
# Add a unique resource tag to the copied PTransforms. The PTransforms
211-
# that are generated from each deep copy have the same unique tag. This is
212-
# to make sure that the PTransforms that are cloned from each deep copy
213-
# can be fused together, but not across copies nor the original.
203+
# Add a resource tag to the copied PTransforms. The PTransforms that are
204+
# generated from each deep copy have the same unique tag. This is to make
205+
# sure that the PTransforms that are cloned from each deep copy can be
206+
# fused together, but not across copies nor with the original.
214207
if tags_resource_available:
215-
copied.resource_hints['beam:resources:tags:v1'] = tag.encode()
208+
copied.resource_hints['beam:resources:tags:v1'] = (
209+
f'DeepCopy.{copy_suffix}'.encode())
216210

217211
ptransform_replacements[item] = copied
218212

0 commit comments

Comments
 (0)