Skip to content

Commit 7525914

Browse files
xmfanpytorchmergebot
authored andcommitted
[compiled autograd] directly use python Logger class in cpp (pytorch#137953)
Pull Request resolved: pytorch#137953 Approved by: https://github.com/jansel, https://github.com/yf225
1 parent 60c1433 commit 7525914

File tree

2 files changed

+61
-21
lines changed

2 files changed

+61
-21
lines changed

torch/_dynamo/compiled_autograd.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,6 @@ def snapshot_verbose_logging_enabled():
4444
)
4545

4646

47-
def cpp_verbose_log_fn(msg: str) -> None:
48-
verbose_log.debug(msg)
49-
50-
5147
def snapshot_cudagraph_enabled():
5248
return torch._inductor.config.triton.cudagraphs
5349

@@ -546,7 +542,7 @@ def enable(compiler_fn):
546542
functools.partial(AutogradCompilerInstance, compiler_fn)
547543
)
548544
if snapshot_verbose_logging_enabled():
549-
torch._C._dynamo.compiled_autograd.set_verbose_logger(cpp_verbose_log_fn)
545+
torch._C._dynamo.compiled_autograd.set_verbose_logger(verbose_log)
550546
global compiled_autograd_enabled
551547
compiled_autograd_enabled = True
552548
try:

torch/csrc/dynamo/python_compiled_autograd.cpp

Lines changed: 60 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,17 @@ static PyObject* convert_hook_list(std::vector<c10::SafePyObject>& inputs) {
6969
return pyinput;
7070
}
7171

72+
// see https://github.com/pytorch/pytorch/pull/34845
73+
static void throw_python_error() {
74+
python_error err;
75+
err.persist();
76+
// NOLINTNEXTLINE(misc-throw-by-value-catch-by-reference)
77+
throw err;
78+
}
79+
7280
static PyObject* check(PyObject* pyresult) {
7381
if (C10_UNLIKELY(pyresult == nullptr)) {
74-
// see https://github.com/pytorch/pytorch/pull/34845
75-
python_error err;
76-
err.persist();
77-
// NOLINTNEXTLINE(misc-throw-by-value-catch-by-reference)
78-
throw err;
82+
throw_python_error();
7983
}
8084
return pyresult;
8185
}
@@ -87,18 +91,58 @@ static void check(bool result) {
8791

8892
// snapshot of python verbose logging toggle
8993
static PyObject* python_verbose_logger = nullptr;
90-
struct VerboseLogger {
94+
95+
struct PythonLogger {
96+
PythonLogger() = delete;
97+
explicit PythonLogger(PyObject* logger) : logger_(logger) {
98+
TORCH_INTERNAL_ASSERT(logger_ != nullptr);
99+
}
100+
101+
enum Level : unsigned int {
102+
DEBUG = 0,
103+
INFO = 1,
104+
WARNING = 2,
105+
ERROR = 3,
106+
CRITICAL = 4,
107+
COUNT // Keep this as the last enum
108+
};
109+
110+
// must be called while GIL is held
111+
void log(Level level, std::string_view msg) const {
112+
THPObjectPtr pymethod(PyUnicode_FromString(levelNames_[level].data()));
113+
TORCH_INTERNAL_ASSERT(pymethod != nullptr);
114+
THPObjectPtr pyfunc(PyObject_GetAttr(logger_, pymethod.get()));
115+
if (pyfunc == nullptr) {
116+
throw_python_error();
117+
}
118+
PyObject* result = PyObject_CallFunction(pyfunc.get(), "s", msg.data());
119+
if (result == nullptr) {
120+
throw_python_error();
121+
}
122+
}
123+
124+
private:
125+
static constexpr std::array<std::string_view, COUNT> levelNames_ = {
126+
"debug", // Level::DEBUG
127+
"info", // Level::INFO
128+
"warning", // Level::WARNING
129+
"error", // Level::ERROR
130+
"critical" // Level::CRITICAL
131+
};
132+
133+
// Note: logger_ must stay valid for the lifetime of this object
134+
PyObject* logger_;
135+
};
136+
137+
struct VerboseLogger : public PythonLogger {
91138
static std::optional<VerboseLogger> maybe_create() {
92139
if (python_verbose_logger == nullptr) {
93140
return std::nullopt;
94141
}
95-
return VerboseLogger();
142+
return VerboseLogger(python_verbose_logger);
96143
}
97144

98-
void verbose_log_fn(std::string_view msg) const {
99-
TORCH_CHECK(python_verbose_logger != nullptr);
100-
check(PyObject_CallFunction(python_verbose_logger, "s", msg.data()));
101-
}
145+
VerboseLogger(PyObject* vlogger) : PythonLogger(vlogger) {}
102146

103147
void log_node_check(
104148
const Node& fn,
@@ -137,7 +181,7 @@ struct VerboseLogger {
137181
}
138182
}
139183
oss << "]";
140-
verbose_log_fn(oss.str());
184+
log(PythonLogger::DEBUG, oss.str());
141185
}
142186

143187
void log_dynamic_shapes_check(size_t size_idx) const {
@@ -149,10 +193,10 @@ struct VerboseLogger {
149193
TORCH_CHECK(it != cumulative_sizes_per_node.end());
150194
size_t start_idx =
151195
it == cumulative_sizes_per_node.begin() ? 0 : std::prev(it)->first;
152-
verbose_log_fn(
196+
log(PythonLogger::DEBUG,
153197
"Cache miss due to changed shapes: marking size idx " +
154-
std::to_string(size_idx - start_idx) + " of " + it->second +
155-
" as dynamic");
198+
std::to_string(size_idx - start_idx) + " of " + it->second +
199+
" as dynamic");
156200
}
157201

158202
// track which size index belongs to which node
@@ -347,7 +391,7 @@ static PyObject* set_verbose_logger(PyObject* dummy, PyObject* args) {
347391
HANDLE_TH_ERRORS;
348392
PyObject* logger = nullptr;
349393
if (!PyArg_ParseTuple(args, "O", &logger)) {
350-
Py_RETURN_FALSE;
394+
throw_python_error();
351395
}
352396

353397
if (logger == Py_None) {

0 commit comments

Comments
 (0)