@@ -69,13 +69,17 @@ static PyObject* convert_hook_list(std::vector<c10::SafePyObject>& inputs) {
6969 return pyinput;
7070}
7171
72+ // see https://github.com/pytorch/pytorch/pull/34845
73+ static void throw_python_error () {
74+ python_error err;
75+ err.persist ();
76+ // NOLINTNEXTLINE(misc-throw-by-value-catch-by-reference)
77+ throw err;
78+ }
79+
7280static PyObject* check (PyObject* pyresult) {
7381 if (C10_UNLIKELY (pyresult == nullptr )) {
74- // see https://github.com/pytorch/pytorch/pull/34845
75- python_error err;
76- err.persist ();
77- // NOLINTNEXTLINE(misc-throw-by-value-catch-by-reference)
78- throw err;
82+ throw_python_error ();
7983 }
8084 return pyresult;
8185}
@@ -87,18 +91,58 @@ static void check(bool result) {
8791
8892// snapshot of python verbose logging toggle
8993static PyObject* python_verbose_logger = nullptr ;
90- struct VerboseLogger {
94+
95+ struct PythonLogger {
96+ PythonLogger () = delete ;
97+ explicit PythonLogger (PyObject* logger) : logger_(logger) {
98+ TORCH_INTERNAL_ASSERT (logger_ != nullptr );
99+ }
100+
101+ enum Level : unsigned int {
102+ DEBUG = 0 ,
103+ INFO = 1 ,
104+ WARNING = 2 ,
105+ ERROR = 3 ,
106+ CRITICAL = 4 ,
107+ COUNT // Keep this as the last enum
108+ };
109+
110+ // must be called while GIL is held
111+ void log (Level level, std::string_view msg) const {
112+ THPObjectPtr pymethod (PyUnicode_FromString (levelNames_[level].data ()));
113+ TORCH_INTERNAL_ASSERT (pymethod != nullptr );
114+ THPObjectPtr pyfunc (PyObject_GetAttr (logger_, pymethod.get ()));
115+ if (pyfunc == nullptr ) {
116+ throw_python_error ();
117+ }
118+ PyObject* result = PyObject_CallFunction (pyfunc.get (), " s" , msg.data ());
119+ if (result == nullptr ) {
120+ throw_python_error ();
121+ }
122+ }
123+
124+ private:
125+ static constexpr std::array<std::string_view, COUNT> levelNames_ = {
126+ " debug" , // Level::DEBUG
127+ " info" , // Level::INFO
128+ " warning" , // Level::WARNING
129+ " error" , // Level::ERROR
130+ " critical" // Level::CRITICAL
131+ };
132+
133+ // Note: logger_ must stay valid for the lifetime of this object
134+ PyObject* logger_;
135+ };
136+
137+ struct VerboseLogger : public PythonLogger {
91138 static std::optional<VerboseLogger> maybe_create () {
92139 if (python_verbose_logger == nullptr ) {
93140 return std::nullopt ;
94141 }
95- return VerboseLogger ();
142+ return VerboseLogger (python_verbose_logger );
96143 }
97144
98- void verbose_log_fn (std::string_view msg) const {
99- TORCH_CHECK (python_verbose_logger != nullptr );
100- check (PyObject_CallFunction (python_verbose_logger, " s" , msg.data ()));
101- }
145+ VerboseLogger (PyObject* vlogger) : PythonLogger(vlogger) {}
102146
103147 void log_node_check (
104148 const Node& fn,
@@ -137,7 +181,7 @@ struct VerboseLogger {
137181 }
138182 }
139183 oss << " ]" ;
140- verbose_log_fn ( oss.str ());
184+ log (PythonLogger::DEBUG, oss.str ());
141185 }
142186
143187 void log_dynamic_shapes_check (size_t size_idx) const {
@@ -149,10 +193,10 @@ struct VerboseLogger {
149193 TORCH_CHECK (it != cumulative_sizes_per_node.end ());
150194 size_t start_idx =
151195 it == cumulative_sizes_per_node.begin () ? 0 : std::prev (it)->first ;
152- verbose_log_fn (
196+ log (PythonLogger::DEBUG,
153197 " Cache miss due to changed shapes: marking size idx " +
154- std::to_string (size_idx - start_idx) + " of " + it->second +
155- " as dynamic" );
198+ std::to_string (size_idx - start_idx) + " of " + it->second +
199+ " as dynamic" );
156200 }
157201
158202 // track which size index belongs to which node
@@ -347,7 +391,7 @@ static PyObject* set_verbose_logger(PyObject* dummy, PyObject* args) {
347391 HANDLE_TH_ERRORS;
348392 PyObject* logger = nullptr ;
349393 if (!PyArg_ParseTuple (args, " O" , &logger)) {
350- Py_RETURN_FALSE ;
394+ throw_python_error () ;
351395 }
352396
353397 if (logger == Py_None) {
0 commit comments