Commit 549405e

yushangdi authored and facebook-github-bot committed
Relax aten.to restriction (pytorch#142420)
Summary: If we have `a.to(b)` and `b` has a different dtype from `a`, the result must be a copy. In that case we do not need to freeze the tensor. Instead, we insert `torch.ops.aten._assert_tensor_metadata.default` to assert at runtime that `a` keeps its export-time dtype, i.e. that it does not share `b`'s dtype. We need to remove `torch.ops.aten._assert_tensor_metadata.default` for ExecuTorch; the pin is updated to include pytorch/executorch#7277.

Test Plan:

```
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:test_export -- -r test_float_conversion
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:test_export -- -r test_device_to_mutation_float
ADD_NEW_STABILITY_CONFIGS=true buck2 run 'fbcode//mode/opt' fbcode//aps_models/ads/modules/module_stability_tests/ctr_cvr_task_arch_stability_tests:inference_stability_tests
buck2 run 'fbcode//mode/dev-nosan' fbcode//deeplearning/aot_inductor/cpu/test:cpu_lowering_utils_test
buck2 run 'fbcode//mode/dev-nosan' //executorch/exir/tests:test_memory_format_ops_pass
buck test aps_models/ads/modules/module_stability_tests/histogram_binning_calibration_tests:calibration_by_feature_stability_test -- --env OVERRIDE_IR_CONFIGS=true
buck test aps_models/ads/modules/module_stability_tests/histogram_binning_calibration_tests:calibration_stability_test -- test_stability_config --env OVERRIDE_IR_CONFIGS=true
buck2 test 'fbcode//mode/opt' fbcode//aps_models/ads/modules/module_stability_tests/soft_max_loss_with_class_weights:stability_tests -- test_stability_config --env OVERRIDE_IR_CONFIGS=true
buck2 test fbcode//aps_models/ads/modules/module_stability_tests/loss_stability_tests:distillation_stability_tests -- --env OVERRIDE_IR_CONFIGS=true
buck2 test fbcode//aps_models/ads/modules/module_stability_tests/loss_stability_tests:stability_tests - aps_models.ads.modules.module_stability_tests.loss_stability_tests.ctr_cvr_logistic_regression_loss_stability_test.CtrCvrLogisticRegressionLossStabilityTest: test_cider_loss -- --env OVERRIDE_IR_CONFIGS=true
buck2 test fbcode//aps_models/ads/modules/module_stability_tests/feature_perturbation_stability_tests:sparse_feature_perturbation_test -- --env OVERRIDE_IR_CONFIGS=true
buck2 test fbcode//aps_models/ads/modules/module_stability_tests/ctr_cvr_task_arch_stability_tests:train_stability_tests -- --env OVERRIDE_IR_CONFIGS=true
buck2 test fbcode//aps_models/ads/modules/module_stability_tests/ctr_cvr_task_arch_stability_tests:inference_stability_tests -- --env OVERRIDE_IR_CONFIGS=true
buck2 test fbcode//aps_models/ads/modules/module_stability_tests/ctr_cvr_task_arch_stability_tests:histogram_binning_stability_test -- --env OVERRIDE_IR_CONFIGS=true
buck2 test fbcode//aps_models/ads/modules/module_stability_tests/aggregator_stability_tests:replicate_aggregator_stability_test -- --env OVERRIDE_IR_CONFIGS=true
buck2 test fbcode//aps_models/ads/modules/module_stability_tests/aggregator_stability_tests:label_aggregator_stability_test -- --env OVERRIDE_IR_CONFIGS=true
buck2 test fbcode//aps_models/ads/modules/module_stability_tests/aggregator_stability_tests:dral_aggregator_stability_test -- --env OVERRIDE_IR_CONFIGS=true
buck2 test fbcode//aps_models/ads/modules/module_stability_tests/adjust_task_weight_stability_tests:task_weight_stability_test -- --env OVERRIDE_IR_CONFIGS=true
buck2 test fbcode//aps_models/ads/modules/loss_modules/module_stability_tests/cider_stability_tests:stability_tests -- --env OVERRIDE_IR_CONFIGS=true
buck2 test fbcode//aps_models/ads/modules/py2pt/calibration/module_stability_tests/histogram_binning_calibration_stability_tests:stability_tests -- --env OVERRIDE_IR_CONFIGS=true
buck2 test fbcode//aps_models/ads/modules/py2pt/calibration/module_stability_tests/histogram_binning_calibration_by_feature_stability_tests:stability_tests -- --env OVERRIDE_IR_CONFIGS=true
```

Reviewed By: bdhirsh, shruthign

Differential Revision: D66988295
1 parent 3251171 · commit 549405e
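A quick standalone illustration of the aliasing point in the summary, in plain eager PyTorch (the tensor values are made up for the example):

```python
import torch

a = torch.tensor([1, 2], dtype=torch.int32)

# Same dtype: .to is free to return the input itself (an alias, not a copy).
same = a.to(torch.int32)
print(same.data_ptr() == a.data_ptr())  # True: no copy happened

# Different dtype: .to has no choice but to materialize a new tensor.
b = a.to(torch.float32)
print(b.data_ptr() == a.data_ptr())  # False: a genuine copy
b[0] = 7.0
print(a[0].item())  # still 1: mutating the copy cannot affect the original
```

Because the different-dtype case can never alias, freezing the output at export time is unnecessary there; asserting the input dtype is enough to keep the copy guaranteed.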

3 files changed: +45 −2 lines changed
Lines changed: 1 addition & 1 deletion (the file path is not captured on this page; per the summary, this is the ExecuTorch commit pin)

```diff
@@ -1 +1 @@
-6f638937d64e3396793956d75ee3e14802022745
+5190106b49125d0370736cd7fea219a6450b5f93
```

test/export/test_export.py

Lines changed: 23 additions & 0 deletions

```diff
@@ -4894,6 +4894,29 @@ def forward(self, x):
         for op in ops:
             self.assertIn(op, (torch.ops.aten._to_copy.default,))
 
+    def test_float_conversion_from_int(self):
+        class Module(torch.nn.Module):
+            def forward(self, x):
+                return x.float()
+
+        ep = export(Module(), (torch.tensor(1, dtype=torch.int32),)).run_decompositions(
+            {}
+        )
+        ops = []
+        for node in ep.graph.nodes:
+            if node.op == "call_function":
+                ops.append(node.target)
+        self.assertGreater(len(ops), 0)
+        self.assertIn(torch.ops.aten._to_copy.default, ops)
+        self.assertIn(torch.ops.aten._assert_tensor_metadata.default, ops)
+
+        self.assertEqual(ep.module()(torch.tensor(1, dtype=torch.int32)), 1)
+
+        # Raises an error because the runtime input dtype differs from the
+        # dtype of the example input used when exporting.
+        with self.assertRaisesRegex(RuntimeError, "Tensor dtype mismatch!"):
+            ep.module()(torch.tensor(1, dtype=torch.float32))
+
     def test_device_to_mutation_float(self):
         class Module(torch.nn.Module):
             def forward(self, x):
```

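The runtime guard the new test checks for is the aten-level metadata assert. A minimal sketch of its behavior when called directly (assuming a PyTorch build that ships this op; the error text follows the regex used in the test):

```python
import torch

t_int = torch.tensor([1], dtype=torch.int32)
t_float = torch.tensor([1.0], dtype=torch.float32)

# dtype matches the one recorded at export time: passes silently.
torch.ops.aten._assert_tensor_metadata.default(t_int, dtype=torch.int32)

# dtype mismatch: raises, just like calling the exported module with a
# float input does in the test above.
try:
    torch.ops.aten._assert_tensor_metadata.default(t_float, dtype=torch.int32)
except RuntimeError as e:
    print(e)  # message starts with "Tensor dtype mismatch!"
```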
torch/_subclasses/functional_tensor.py

Lines changed: 21 additions & 1 deletion

```diff
@@ -535,7 +535,27 @@ def unwrap(x):
                 torch.ops.aten.dropout.default,
                 torch.ops.aten._to_copy.default,
             ):
-                torch._freeze_functional_tensor(outs_unwrapped)  # type: ignore[attr-defined]
+
+                def must_copy():
+                    """
+                    Return True if the output of the op must be a copy, not an alias.
+                    """
+                    # output dtype is different from input
+                    return (
+                        func == torch.ops.aten._to_copy.default
+                        and "dtype" in kwargs
+                        and kwargs["dtype"] != args_unwrapped[0].dtype
+                    )
+
+                if must_copy():
+                    # We can further relax to args_unwrapped[0] != kwargs["dtype"], but I don't think
+                    # we have an aten op for that.
+                    torch.ops.aten._assert_tensor_metadata.default(
+                        torch._from_functional_tensor(args_unwrapped[0]),
+                        dtype=args_unwrapped[0].dtype,
+                    )
+                else:
+                    torch._freeze_functional_tensor(outs_unwrapped)  # type: ignore[attr-defined]
         outs_wrapped = pytree.tree_map_only(
             torch.Tensor, wrap, outs_unwrapped
         )
```

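To see which calls take the new assert path versus the old freeze path, here is a standalone restatement of the `must_copy` predicate above (written as a free function purely for illustration; in the real code it closes over `func`, `kwargs`, and `args_unwrapped`):

```python
import torch

def must_copy(func, args, kwargs):
    # Mirrors the predicate added above: aten._to_copy with a dtype kwarg
    # that differs from the input's dtype can never alias its input.
    return (
        func == torch.ops.aten._to_copy.default
        and "dtype" in kwargs
        and kwargs["dtype"] != args[0].dtype
    )

x = torch.tensor([1], dtype=torch.int32)
to_copy = torch.ops.aten._to_copy.default

print(must_copy(to_copy, (x,), {"dtype": torch.float32}))  # True: assert path
print(must_copy(to_copy, (x,), {"dtype": torch.int32}))    # False: freeze path
print(must_copy(to_copy, (x,), {}))                        # False: freeze path
```

Only the dtype-changing case skips `torch._freeze_functional_tensor`; a `_to_copy` with no dtype change (or no dtype kwarg at all) keeps the old freezing behavior.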