Skip to content

Commit 5b9db43

Browse files
bobrenjc93pytorchmergebot
authored andcommitted
Include c++ stack traces when we hit constraint violation (pytorch#155603)
Example new error message ``` torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['x'].size()[0])! For more information, run with TORCH_LOGS="+dynamic". - You marked L['x'].size()[0] as dynamic but your code specialized it to be a constant (5). Either remove the mark_dynamic or use a less strict API such as maybe_mark_dynamic or Dim.AUTO. Framework stack: File "??", line 0, in _start File "", line 0, in __libc_start_main_alias_2 File "??", line 0, in __libc_start_call_main File "/usr/local/src/conda/python-3.10.16/Modules/main.c", line 1094, in Py_BytesMain File "/usr/local/src/conda/python-3.10.16/Modules/main.c", line 357, in pymain_run_file_obj File "/usr/local/src/conda/python-3.10.16/Python/pythonrun.c", line 90, in _PyRun_AnyFileObject File "/usr/local/src/conda/python-3.10.16/Python/pythonrun.c", line 456, in _PyRun_SimpleFileObject File "/usr/local/src/conda/python-3.10.16/Python/pythonrun.c", line 1208, in pyrun_file File "/usr/local/src/conda/python-3.10.16/Python/pythonrun.c", line 1312, in run_mod File "/usr/local/src/conda/python-3.10.16/Python/pythonrun.c", line 1291, in run_eval_code_obj File "/usr/local/src/conda/python-3.10.16/Python/ceval.c", line 1134, in PyEval_EvalCode File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 114, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/scratch/repro.py", line 9, in <module> foo(x) File "/usr/local/src/conda/python-3.10.16/Python/ceval.c", line 5945, in do_call_core File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/_dynamo/eval_frame.py", line 699, in compile_wrapper return fn(*args, **kwargs) File "offloadstuff.c", line 0, in dynamo__custom_eval_frame File "/usr/local/src/conda/python-3.10.16/Objects/call.c", line 305, in _PyObject_Call File "/usr/local/src/conda/python-3.10.16/Objects/typeobject.c", line 7494, in slot_tp_call File "/usr/local/src/conda/python-3.10.16/Objects/call.c", line 431, in _PyObject_Call_Prepend File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/_dynamo/convert_frame.py", line 1469, in __call__ return self._torchdynamo_orig_callable( File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 112, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Objects/call.c", line 215, in _PyObject_MakeTpCall File "/usr/local/src/conda/python-3.10.16/Objects/typeobject.c", line 7494, in slot_tp_call File "/usr/local/src/conda/python-3.10.16/Objects/call.c", line 431, in _PyObject_Call_Prepend File "/usr/local/src/conda/python-3.10.16/Objects/call.c", line 153, in _PyObject_FastCallDictTstate File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/_dynamo/convert_frame.py", line 1248, in __call__ result = self._inner_convert( File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 112, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Objects/call.c", line 215, in _PyObject_MakeTpCall File "/usr/local/src/conda/python-3.10.16/Objects/typeobject.c", line 7494, in slot_tp_call File "/usr/local/src/conda/python-3.10.16/Objects/call.c", line 431, in _PyObject_Call_Prepend File "/usr/local/src/conda/python-3.10.16/Objects/call.c", line 153, in _PyObject_FastCallDictTstate File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/_dynamo/convert_frame.py", line 625, in __call__ return _compile( File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 114, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/_dynamo/convert_frame.py", line 1092, in _compile guarded_code = compile_inner(code, one_graph, hooks, transform) File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 114, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/_utils_internal.py", line 97, in wrapper_function return function(*args, **kwargs) File "/usr/local/src/conda/python-3.10.16/Python/ceval.c", line 5945, in do_call_core File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/_dynamo/convert_frame.py", line 779, in compile_inner return _compile_inner(code, one_graph, hooks, transform) File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 114, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/_dynamo/convert_frame.py", line 818, in _compile_inner out_code = transform_code_object(code, transform) File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 114, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/_dynamo/bytecode_transformation.py", line 1424, in transform_code_object transformations(instructions, code_options) File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 114, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/_dynamo/convert_frame.py", line 265, in _fn return fn(*args, **kwargs) File "/usr/local/src/conda/python-3.10.16/Python/ceval.c", line 5945, in do_call_core File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/_dynamo/convert_frame.py", line 743, in transform tracer.run() File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 114, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/_dynamo/symbolic_convert.py", line 3531, in run super().run() File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 114, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/_dynamo/symbolic_convert.py", line 1359, in run while self.step(): File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 114, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/_dynamo/symbolic_convert.py", line 1263, in step self.dispatch_table[inst.opcode](self, inst) File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 114, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/_dynamo/symbolic_convert.py", line 422, in impl self.push(fn_var.call_function(self, self.popn(nargs), {})) File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 114, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/_dynamo/variables/builtin.py", line 1160, in call_function return handler(tx, args, kwargs) File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 114, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/_dynamo/variables/builtin.py", line 792, in <lambda> return lambda tx, args, kwargs: obj.call_function( File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 114, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/_dynamo/variables/builtin.py", line 1160, in call_function return handler(tx, args, kwargs) File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 114, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/_dynamo/variables/builtin.py", line 1120, in _handle_insert_op_in_graph return wrap_fx_proxy(tx, proxy) File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 114, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/_dynamo/variables/builder.py", line 2500, in wrap_fx_proxy return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs) File "/usr/local/src/conda/python-3.10.16/Python/ceval.c", line 5945, in do_call_core File "/usr/local/src/conda/python-3.10.16/Objects/call.c", line 267, in PyVectorcall_Call File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/_dynamo/variables/builder.py", line 2566, in wrap_fx_proxy_cls return _wrap_fx_proxy( File "/usr/local/src/conda/python-3.10.16/Python/ceval.c", line 5945, in do_call_core File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/_dynamo/variables/builder.py", line 2664, in _wrap_fx_proxy example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True) File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 114, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/_dynamo/utils.py", line 3205, in get_fake_value ret_val = wrap_fake_exception( File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 114, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/_dynamo/utils.py", line 2705, in wrap_fake_exception return fn() File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 114, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/_dynamo/utils.py", line 3206, in <lambda> lambda: run_node(tx.output, node, args, kwargs, nnmodule) File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 114, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/_dynamo/utils.py", line 3373, in run_node return node.target(*args, **kwargs) File "/usr/local/src/conda/python-3.10.16/Python/ceval.c", line 5917, in do_call_core File "/usr/local/src/conda/python-3.10.16/Objects/methodobject.c", line 430, in cfunction_vectorcall_FASTCALL File "/usr/local/src/conda/python-3.10.16/Objects/abstract.c", line 891, in binary_op1 File "/usr/local/src/conda/python-3.10.16/Objects/typeobject.c", line 7284, in slot_nb_multiply File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 114, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Objects/descrobject.c", line 344, in method_vectorcall_VARARGS_KEYWORDS File "python_variable_methods.cpp", line 0, in _object* torch::autograd::TypeError_to_NotImplemented_<&torch::autograd::THPVariable_mul>(_object*, _object*, _object*) File "python_variable_methods.cpp", line 0, in torch::autograd::THPVariable_mul(_object*, _object*, _object*) File "??", line 0, in at::_ops::mul_Tensor::call(at::Tensor const&, at::Tensor const&) File "offloadstuff.c", line 0, in c10::impl::BoxedKernelWrapper<at::Tensor (at::Tensor const&, at::Tensor const&), void>::call(c10::BoxedKernel const&, c10::OperatorHandle const&, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&) File "PyInterpreter.cpp", line 0, in torch::detail::(anonymous namespace)::ConcretePyInterpreterVTable::python_dispatcher(c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) const File "offloadstuff.c", line 0, in c10::OperatorHandle::callBoxedForDispatchKey(c10::DispatchKey, std::vector<c10::IValue, std::allocator<c10::IValue> >&) const File "PythonFallbackKernel.cpp", line 0, in void c10::BoxedKernel::make_boxed_function<&(anonymous namespace)::pythonTLSSnapshotFallback>(c10::OperatorKernel*, c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) File "PyInterpreter.cpp", line 0, in torch::detail::(anonymous namespace)::ConcretePyInterpreterVTable::python_dispatcher(c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) const File "offloadstuff.c", line 0, in c10::OperatorHandle::callBoxedForDispatchKey(c10::DispatchKey, std::vector<c10::IValue, std::allocator<c10::IValue> >&) const File "VariableType_0.cpp", line 0, in c10::impl::make_boxed_from_unboxed_functor<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (c10::DispatchKeySet, at::Tensor const&, at::Tensor const&), &torch::autograd::VariableType::(anonymous namespace)::mul_Tensor>, at::Tensor, c10::guts::typelist::typelist<c10::DispatchKeySet, at::Tensor const&, at::Tensor const&> >, false>::call(c10::OperatorKernel*, c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) File "VariableType_0.cpp", line 0, in torch::autograd::VariableType::(anonymous namespace)::mul_Tensor(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&) File "??", line 0, in at::_ops::mul_Tensor::redispatch(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&) File "offloadstuff.c", line 0, in c10::impl::BoxedKernelWrapper<at::Tensor (at::Tensor const&, at::Tensor const&), void>::call(c10::BoxedKernel const&, c10::OperatorHandle const&, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&) File "PyInterpreter.cpp", line 0, in torch::detail::(anonymous namespace)::ConcretePyInterpreterVTable::python_dispatcher(c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) const File "offloadstuff.c", line 0, in c10::OperatorHandle::callBoxedForDispatchKey(c10::DispatchKey, std::vector<c10::IValue, std::allocator<c10::IValue> >&) const File "PythonFallbackKernel.cpp", line 0, in (anonymous namespace)::pythonFallback(c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) File "PyInterpreter.cpp", line 0, in torch::detail::(anonymous namespace)::ConcretePyInterpreterVTable::dispatch(c10::OperatorHandle const&, std::vector<c10::IValue, std::allocator<c10::IValue> >*) const File "??", line 0, in torch::handle_torch_function_no_python_arg_parser(c10::ArrayRef<_object*>, _object*, _object*, char const*, _object*, char const*, torch::TorchFunctionName) File "/usr/local/src/conda/python-3.10.16/Objects/call.c", line 577, in PyObject_CallMethod File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 114, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/utils/_stats.py", line 27, in wrapper return fn(*args, **kwargs) File "/usr/local/src/conda/python-3.10.16/Python/ceval.c", line 5945, in do_call_core File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/_subclasses/fake_tensor.py", line 1346, in __torch_dispatch__ return self.dispatch(func, types, args, kwargs) File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 114, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/_subclasses/fake_tensor.py", line 2029, in dispatch return self._cached_dispatch_impl(func, types, args, kwargs) File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 114, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/_subclasses/fake_tensor.py", line 1442, in _cached_dispatch_impl return self._dispatch_impl(func, types, args, kwargs) File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 114, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/_subclasses/fake_tensor.py", line 2552, in _dispatch_impl return maybe_propagate_real_tensors(fast_impl(self, *args, **kwargs)) File "/usr/local/src/conda/python-3.10.16/Python/ceval.c", line 5945, in do_call_core File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/_subclasses/fake_impls.py", line 956, in fast_binary_impl final_shape = infer_size(final_shape, shape) File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 114, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/_subclasses/fake_impls.py", line 916, in infer_size torch._check( File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 114, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/__init__.py", line 1669, in _check _check_with(RuntimeError, cond, message) File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 114, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/__init__.py", line 1632, in _check_with if expect_true(cond): File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 114, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/fx/experimental/symbolic_shapes.py", line 1686, in expect_true return a.node.expect_true( File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 114, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/fx/experimental/sym_node.py", line 552, in expect_true return self.guard_bool(file, line) File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 114, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/fx/experimental/sym_node.py", line 536, in guard_bool r = self.evaluate() File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 114, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/fx/experimental/sym_node.py", line 510, in evaluate return self.shape_env.evaluate_sym_node(self, size_oblivious) File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 114, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/fx/experimental/symbolic_shapes.py", line 7113, in evaluate_sym_node return self.evaluate_expr( File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 114, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 112, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Objects/call.c", line 215, in _PyObject_MakeTpCall File "/usr/local/src/conda/python-3.10.16/Modules/_functoolsmodule.c", line 1020, in bounded_lru_cache_wrapper File "/usr/local/src/conda/python-3.10.16/Objects/call.c", line 267, in PyVectorcall_Call File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/fx/experimental/recording.py", line 272, in wrapper return retlog(fn(*args, **kwargs)) File "/usr/local/src/conda/python-3.10.16/Python/ceval.c", line 5945, in do_call_core File "/usr/local/src/conda/python-3.10.16/Objects/call.c", line 267, in PyVectorcall_Call File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/fx/experimental/symbolic_shapes.py", line 7215, in evaluate_expr return self._inner_evaluate_expr( File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 112, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Objects/call.c", line 215, in _PyObject_MakeTpCall File "/usr/local/src/conda/python-3.10.16/Modules/_functoolsmodule.c", line 1020, in bounded_lru_cache_wrapper File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/fx/experimental/recording.py", line 272, in wrapper return retlog(fn(*args, **kwargs)) File "/usr/local/src/conda/python-3.10.16/Python/ceval.c", line 5945, in do_call_core File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/fx/experimental/symbolic_shapes.py", line 7238, in _inner_evaluate_expr return self._evaluate_expr( File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 114, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/fx/experimental/symbolic_shapes.py", line 7505, in _evaluate_expr self._maybe_guard_rel(g) File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 112, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Objects/call.c", line 215, in _PyObject_MakeTpCall File "/usr/local/src/conda/python-3.10.16/Modules/_functoolsmodule.c", line 1020, in bounded_lru_cache_wrapper File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/fx/experimental/symbolic_shapes.py", line 6758, in _maybe_guard_rel self._refine_ranges(expr) File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 114, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/fx/experimental/symbolic_shapes.py", line 7709, in _refine_ranges self._set_replacement( File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 114, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/fx/experimental/symbolic_shapes.py", line 6667, in _set_replacement self.framework_specialization_stacks[source] = CapturedTraceback.extract(cpp=True) File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 114, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Include/internal/pycore_ceval.h", line 46, in _PyEval_EvalFrame File "/home/bobren/local/a/pytorch/torch/utils/_traceback.py", line 207, in extract torch._C._profiler.gather_traceback(python=True, script=script, cpp=cpp), File "/usr/local/src/conda/python-3.10.16/Include/cpython/abstract.h", line 112, in _PyObject_VectorcallTstate File "/usr/local/src/conda/python-3.10.16/Objects/call.c", line 215, in _PyObject_MakeTpCall File "/usr/local/src/conda/python-3.10.16/Objects/methodobject.c", line 543, in cfunction_call File "offloadstuff.c", line 0, in pybind11::cpp_function::dispatcher(_object*, _object*, _object*) File "offloadstuff.c", line 0, in pybind11::cpp_function::initialize<std::shared_ptr<torch::CapturedTraceback> (*&)(bool, bool, bool), std::shared_ptr<torch::CapturedTraceback>, bool, bool, bool, pybind11::name, pybind11::scope, pybind11::sibling, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v>(std::shared_ptr<torch::CapturedTraceback> (*&)(bool, bool, bool), std::shared_ptr<torch::CapturedTraceback> (*)(bool, bool, bool), pybind11::name const&, pybind11::scope const&, pybind11::sibling const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&)::{lambda(pybind11::detail::function_call&)#3}::_FUN(pybind11::detail::function_call&) File "??", line 0, in torch::CapturedTraceback::gather(bool, bool, bool) File "??", line 0, in torch::unwind::unwind() User stack: File "/home/bobren/local/a/pytorch/scratch/repro.py", line 5, in foo return torch.randn(5) * x ``` Pull Request resolved: pytorch#155603 Approved by: https://github.com/zou3519, https://github.com/cyyever ghstack dependencies: pytorch#155133
1 parent 84c1436 commit 5b9db43

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

torch/fx/experimental/symbolic_shapes.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3837,7 +3837,8 @@ def _init(
38373837
# with something like effect token tracking.
38383838
self.unbacked_alloc_order: dict[sympy.Symbol, int] = {}
38393839

3840-
self.specialization_stacks: dict[Source, traceback.StackSummary] = {}
3840+
self.user_specialization_stacks: dict[Source, traceback.StackSummary] = {}
3841+
self.framework_specialization_stacks: dict[Source, traceback.StackSummary] = {}
38413842

38423843
self.trace_asserts = trace_asserts
38433844

@@ -3968,7 +3969,8 @@ def check_equal(self, other: ShapeEnv) -> None:
39683969
"replacements_slocs",
39693970
"_resimplify_floor_div_axioms",
39703971
"_expr_sym_node_id",
3971-
"specialization_stacks",
3972+
"user_specialization_stacks",
3973+
"framework_specialization_stacks",
39723974
)
39733975

39743976
# Mapping of the value of each to-be-compared field into the values that
@@ -5515,11 +5517,19 @@ def hint(s: sympy.Expr) -> str:
55155517
var_with_range = self._render_range_for_constraint_violation(
55165518
source, constraint
55175519
)
5518-
user_stack = self.specialization_stacks.get(source, None)
5520+
user_stack = self.user_specialization_stacks.get(source, None)
5521+
framework_stack = self.framework_specialization_stacks.get(
5522+
source, None
5523+
)
55195524
msg = (
55205525
f"You marked {self._debug_name(source)} as dynamic but your code "
55215526
f"specialized it to be a constant ({val}). Either remove the mark_dynamic "
55225527
f"or use a less strict API such as maybe_mark_dynamic or Dim.AUTO."
5528+
+ (
5529+
"\n\nFramework stack:\n" + "".join(framework_stack.format())
5530+
if framework_stack
5531+
else ""
5532+
)
55235533
+ (
55245534
"\n\nUser stack:\n" + "".join(user_stack.format())
55255535
if user_stack
@@ -6742,7 +6752,10 @@ def _set_replacement(self, a: sympy.Symbol, tgt: sympy.Expr, msg: str) -> None:
67426752

67436753
for source in self.var_to_sources.get(a, []):
67446754
if user_tb:
6745-
self.specialization_stacks[source] = user_tb
6755+
self.user_specialization_stacks[source] = user_tb
6756+
self.framework_specialization_stacks[
6757+
source
6758+
] = CapturedTraceback.extract(cpp=True)
67466759

67476760
if config.print_specializations:
67486761
self.log.warning(

0 commit comments

Comments
 (0)