Skip to content

Commit 058ab35

Browse files
yushangdi authored and facebook-github-bot committed
Relax aten.to restriction (pytorch#142420)
Summary: if we have a.to(b), and b has a different dtype with a, then it must be a copy. In this case, we do not need to freeze the tensor. Instead, we use torch.ops.aten._assert_tensor_metadata.default to ensure that a must not have the same dtype as b. We need to remove ` torch.ops.aten._assert_tensor_metadata.default` for executorch. Update to pin to include pytorch/executorch#7277. Test Plan: ``` buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:test_export -- -r test_float_conversion buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:test_export -- -r test_device_to_mutation_float ADD_NEW_STABILITY_CONFIGS=true buck2 run 'fbcode//mode/opt' fbcode//aps_models/ads/modules/module_stability_tests/ctr_cvr_task_arch_stability_tests:inference_stability_tests buck2 run 'fbcode//mode/dev-nosan' fbcode//deeplearning/aot_inductor/cpu/test:cpu_lowering_utils_test buck2 run 'fbcode//mode/dev-nosan' //executorch/exir/tests:test_memory_format_ops_pass ``` Reviewed By: bdhirsh Differential Revision: D66988295
1 parent 135a2d4 commit 058ab35

File tree

3 files changed

+45
-2
lines changed

3 files changed

+45
-2
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
6f638937d64e3396793956d75ee3e14802022745
1+
5b6d35b156067d47b0a89aa788c645b8320d8d3d

test/export/test_export.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4649,6 +4649,29 @@ def forward(self, x):
46494649
for op in ops:
46504650
self.assertIn(op, (torch.ops.aten._to_copy.default,))
46514651

4652+
def test_float_conversion_from_int(self):
    """An int32 -> float conversion must survive export as a genuine copy
    (`aten._to_copy`) guarded by an `aten._assert_tensor_metadata` check
    on the input's dtype, and the guard must fire on a dtype mismatch.
    """

    class Module(torch.nn.Module):
        def forward(self, x):
            return x.float()

    ep = export(Module(), (torch.tensor(1, dtype=torch.int32),)).run_decompositions(
        {}
    )
    # Collect every call_function target in the decomposed graph.
    call_targets = [
        node.target for node in ep.graph.nodes if node.op == "call_function"
    ]
    self.assertGreater(len(call_targets), 0)
    self.assertIn(torch.ops.aten._to_copy.default, call_targets)
    self.assertIn(torch.ops.aten._assert_tensor_metadata.default, call_targets)

    # Matching dtype at runtime: the module runs and converts the value.
    self.assertEqual(ep.module()(torch.tensor(1, dtype=torch.int32)), 1)

    # Mismatched dtype at runtime: the metadata assertion baked into the
    # graph rejects a float input where an int32 was traced.
    with self.assertRaisesRegex(RuntimeError, "Tensor dtype mismatch!"):
        ep.module()(torch.tensor(1, dtype=torch.float32))
4674+
46524675
def test_device_to_mutation_float(self):
46534676
class Module(torch.nn.Module):
46544677
def forward(self, x):

torch/_subclasses/functional_tensor.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -535,7 +535,27 @@ def unwrap(x):
535535
torch.ops.aten.dropout.default,
536536
torch.ops.aten._to_copy.default,
537537
):
538-
torch._freeze_functional_tensor(outs_unwrapped) # type: ignore[attr-defined]
538+
539+
def must_copy():
    """Report whether this op's output is necessarily a copy, not an alias.

    A `_to_copy` whose requested dtype differs from the input's dtype
    cannot alias its input, so the caller may skip freezing the output.
    """
    if func != torch.ops.aten._to_copy.default or "dtype" not in kwargs:
        return False
    # Output dtype differs from the input's dtype -> guaranteed copy.
    return kwargs["dtype"] != args_unwrapped[0].dtype
549+
550+
if must_copy():
551+
# We can further relax to args_unwrapped[0] != kwargs["dtype"], but I don't think
552+
# we have an aten op for that.
553+
torch.ops.aten._assert_tensor_metadata.default(
554+
torch._from_functional_tensor(args_unwrapped[0]),
555+
dtype=args_unwrapped[0].dtype,
556+
)
557+
else:
558+
torch._freeze_functional_tensor(outs_unwrapped) # type: ignore[attr-defined]
539559
outs_wrapped = pytree.tree_map_only(
540560
torch.Tensor, wrap, outs_unwrapped
541561
)

0 commit comments

Comments
 (0)