Skip to content

Commit 378cd99

Browse files
authored
Fix for inconsistencies between numba and the wrapped print (#4060)
* numba bugfix * correction * moving the function * removed the default * formatting * print doesnt return anything
1 parent 360ea46 commit 378cd99

File tree

1 file changed

+56
-54
lines changed

1 file changed

+56
-54
lines changed

src/zenml/logger.py

Lines changed: 56 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -240,69 +240,71 @@ def set_root_verbosity() -> None:
240240
get_logger(__name__).debug("Logging NOTSET")
241241

242242

243+
def wrapped_print(*args: Any, **kwargs: Any) -> None:
244+
"""Wrapped print function.
245+
246+
Args:
247+
*args: Arguments to print
248+
**kwargs: Keyword arguments for print
249+
"""
250+
original_print = getattr(builtins, "_zenml_original_print")
251+
252+
file_arg = kwargs.get("file", sys.stdout)
253+
254+
# IMPORTANT: Don't intercept internal calls to any objects
255+
# other than sys.stdout and sys.stderr. This is especially
256+
# critical for handling tracebacks. The default logging
257+
# formatter uses StringIO to format tracebacks, we don't
258+
# want to intercept it and create a LogRecord about it.
259+
if file_arg not in (sys.stdout, sys.stderr):
260+
original_print(*args, **kwargs)
261+
262+
# Convert print arguments to message
263+
message = " ".join(str(arg) for arg in args)
264+
265+
# Call active handlers first (for storage)
266+
if message.strip():
267+
handlers = logging_handlers.get()
268+
269+
for handler in handlers:
270+
try:
271+
# Create a LogRecord for the handler
272+
record = logging.LogRecord(
273+
name="print",
274+
level=logging.ERROR
275+
if file_arg == sys.stderr
276+
else logging.INFO,
277+
pathname="",
278+
lineno=0,
279+
msg=message,
280+
args=(),
281+
exc_info=None,
282+
)
283+
# Check if handler's level would accept this record
284+
if record.levelno >= handler.level:
285+
handler.emit(record)
286+
except Exception:
287+
# Don't let handler errors break print
288+
pass
289+
290+
if step_names_in_console.get():
291+
message = _add_step_name_to_message(message)
292+
293+
# Then call original print for console display
294+
original_print(message, *args[1:], **kwargs)
295+
296+
243297
def setup_global_print_wrapping() -> None:
244298
"""Set up global print() wrapping with context-aware handlers."""
245-
# Check if we should capture prints
246299
capture_prints = handle_bool_env_var(
247300
ENV_ZENML_CAPTURE_PRINTS, default=True
248301
)
249302

250-
if not capture_prints:
303+
if not capture_prints or hasattr(__builtins__, "_zenml_original_print"):
251304
return
252305

253-
# Check if already wrapped to avoid double wrapping
254-
if hasattr(__builtins__, "_zenml_original_print"):
255-
return
256-
257-
original_print = builtins.print
258-
259-
def wrapped_print(*args: Any, **kwargs: Any) -> None:
260-
file_arg = kwargs.get("file", sys.stdout)
261-
262-
# IMPORTANT: Don't intercept internal calls to any objects
263-
# other than sys.stdout and sys.stderr. This is especially
264-
# critical for handling tracebacks. The default logging
265-
# formatter uses StringIO to format tracebacks, we don't
266-
# want to intercept it and create a LogRecord about it.
267-
if file_arg not in (sys.stdout, sys.stderr):
268-
return original_print(*args, **kwargs)
269-
270-
# Convert print arguments to message
271-
message = " ".join(str(arg) for arg in args)
272-
273-
# Call active handlers first (for storage)
274-
if message.strip():
275-
handlers = logging_handlers.get()
276-
277-
for handler in handlers:
278-
try:
279-
# Create a LogRecord for the handler
280-
record = logging.LogRecord(
281-
name="print",
282-
level=logging.ERROR
283-
if file_arg == sys.stderr
284-
else logging.INFO,
285-
pathname="",
286-
lineno=0,
287-
msg=message,
288-
args=(),
289-
exc_info=None,
290-
)
291-
# Check if handler's level would accept this record
292-
if record.levelno >= handler.level:
293-
handler.emit(record)
294-
except Exception:
295-
# Don't let handler errors break print
296-
pass
297-
298-
if step_names_in_console.get():
299-
message = _add_step_name_to_message(message)
300-
301-
# Then call original print for console display
302-
return original_print(message, *args[1:], **kwargs)
303-
304306
# Store original and replace print
305-
setattr(builtins, "_zenml_original_print", original_print)
307+
setattr(builtins, "_zenml_original_print", builtins.print)
306308
setattr(builtins, "print", wrapped_print)
307309

308310

0 commit comments

Comments
 (0)