Skip to content
This repository was archived by the owner on Feb 3, 2025. It is now read-only.

Commit a40ca6a

Browse files
author
DEKHTIARJonathan
committed
[benchmarking Py] TF Logging Filtering Improved
1 parent ccc606f commit a40ca6a

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

tftrt/benchmarking-python/benchmark_runner.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import copy
88
import csv
99
import json
10-
import logging as _logging
1110
import os
1211
import requests
1312
import sys
@@ -22,6 +21,7 @@
2221

2322
from tensorflow.python.compiler.tensorrt import trt_convert as trt
2423
from tensorflow.python.framework.errors_impl import OutOfRangeError
24+
from tensorflow.python.platform import tf_logging
2525
from tensorflow.python.saved_model import signature_constants
2626
from tensorflow.python.saved_model import tag_constants
2727

@@ -87,8 +87,19 @@ def __init__(self, args):
8787
os.environ["NVIDIA_TF32_OVERRIDE"] = "0"
8888

8989
# Hide unnecessary TensorFlow DEBUG Python Logs
90-
_logging.getLogger("tensorflow").setLevel(_logging.INFO)
91-
_logging.disable(_logging.WARNING)
90+
tf_logger = tf_logging.get_logger()
91+
tf_logger.setLevel(tf_logging.INFO)
92+
tf_logger.propagate = False
93+
94+
# disable TF warnings
95+
tf_logging.get_logger().warning = lambda *a, **kw: None
96+
tf_logging.get_logger().warn = lambda *a, **kw: None
97+
old_log = tf_logging.get_logger().log
98+
tf_logging.get_logger().log = lambda level, msg, *a, **kw: (
99+
old_log(level, msg, *a, **kw)
100+
if level != tf_logging.WARN else
101+
None
102+
)
92103

93104
# TensorFlow can execute operations synchronously or asynchronously.
94105
# If asynchronous execution is enabled, operations may return

0 commit comments

Comments
 (0)