
Commit 104df40

Author: DEKHTIARJonathan
Commit message: [Benchmarking] Improving autotuning decorator logging
Parent: 903e059

File tree: 2 files changed (+31, -17 lines)

tftrt/benchmarking-python/benchmark_autotuner.py
Lines changed: 22 additions & 15 deletions
@@ -3,6 +3,7 @@
 # -*- coding: utf-8 -*-

 import time
+
 import numpy as np
 import tensorflow as tf

@@ -86,31 +87,37 @@ def auto_tf_func_tuner(

     def wrapper(func):

-        @force_gpu_resync
-        def eager_function(*args, **kwargs):
-            return func(*args, **kwargs)
+        func_name = func.__name__
+
+        eager_function = func

-        @force_gpu_resync
-        @tf.function(jit_compile=use_xla)
-        def tf_function(*args, **kwargs):
-            return func(*args, **kwargs)
+        tf_function = tf.function(jit_compile=use_xla)(func)

-        @force_gpu_resync
-        @_force_using_concrete_function
-        @tf.function(jit_compile=use_xla)
-        def tf_concrete_function(*args, **kwargs):
-            return func(*args, **kwargs)
+        def resync_gpu_wrap_fn(_func, str_appended):
+            name = f"{func_name}_{str_appended}"
+            _func.__name__ = name
+            _func = force_gpu_resync(_func)
+            _func.__name__ = name
+            return _func

-        eager_function.__name__ = f"{func.__name__}_eager"
-        tf_function.__name__ = f"{func.__name__}_tf_function"
-        tf_concrete_function.__name__ = f"{func.__name__}_tf_concrete_function"
+        eager_function = resync_gpu_wrap_fn(eager_function, "eager")
+        tf_function = resync_gpu_wrap_fn(tf_function, "tf_function")

         funcs2autotune = [eager_function, tf_function]
+
         if use_synthetic_data:
             print(
                 "[INFO] Allowing direct concrete_function call with "
                 "synthetic data loader."
             )
+
+            tf_concrete_function = _force_using_concrete_function(
+                tf.function(jit_compile=use_xla)(func)
+            )
+            tf_concrete_function = resync_gpu_wrap_fn(
+                tf_concrete_function, "tf_concrete_function"
+            )
+
             funcs2autotune.append(tf_concrete_function)

         return _TFFunctionAutoTuner(
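
Side note on the double __name__ assignment inside resync_gpu_wrap_fn: the name is set once before force_gpu_resync runs, so the decorator's log lines report the suffixed name, and set again afterwards, because the decorator returns a brand-new wrapper function carrying its own __name__. A minimal, self-contained sketch of that pattern (logging_decorator, wrap_with_name, and my_benchmark are hypothetical names for illustration, not part of this repository):

# Hypothetical stand-in for force_gpu_resync: logs the name it sees,
# then returns a fresh wrapper whose __name__ would default to "wrapper".
def logging_decorator(func):
    print(f"[INFO] Decorating: {func.__name__}")

    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return wrapper


# Same shape as resync_gpu_wrap_fn in the diff above.
def wrap_with_name(func, suffix):
    name = f"{func.__name__}_{suffix}"
    func.__name__ = name            # visible inside logging_decorator
    func = logging_decorator(func)  # returns a new function object
    func.__name__ = name            # restore the name on the new wrapper
    return func


def my_benchmark():
    pass


wrapped = wrap_with_name(my_benchmark, "eager")
print(wrapped.__name__)  # -> my_benchmark_eager

Without the second assignment, wrapped.__name__ would report as "wrapper", which is the kind of uninformative log output this commit is cleaning up.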

tftrt/benchmarking-python/benchmark_utils.py
Lines changed: 9 additions & 2 deletions
@@ -11,9 +11,14 @@


 def force_gpu_resync(func):
+
+    func_name = func.__name__
     try:
         sync_device_fn = tf.experimental.sync_devices
-        print("[INFO] Using API `tf.experimental.sync_devices` to resync GPUs.")
+        print(
+            "[INFO] Using API `tf.experimental.sync_devices` to resync GPUs "
+            f"on function: {func_name}."
+        )

         def wrapper(*args, **kwargs):
             rslt = func(*args, **kwargs)
@@ -25,8 +30,10 @@ def wrapper(*args, **kwargs):
     except AttributeError:
         print(
             "[WARNING] Using deprecated API to resync GPUs. "
-            "Non negligeable overhead might be present."
+            "Non negligeable overhead might be present on function: "
+            f"{func_name}."
         )
+
         p = tf.constant(0.)  # Create small tensor to force GPU resync

         def wrapper(*args, **kwargs):
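
For context, force_gpu_resync feature-detects tf.experimental.sync_devices (available in newer TensorFlow releases) and falls back to a small-tensor trick on older ones, where accessing the attribute raises AttributeError. A hedged sketch of the complete decorator under that reading; the two wrapper bodies beyond what the hunks show (the sync_device_fn() call and the host readback) are reconstructions, not verbatim repository code:

import tensorflow as tf


def force_gpu_resync(func):
    func_name = func.__name__
    try:
        # Preferred path: explicit device synchronization in newer TF.
        sync_device_fn = tf.experimental.sync_devices
        print(
            "[INFO] Using API `tf.experimental.sync_devices` to resync GPUs "
            f"on function: {func_name}."
        )

        def wrapper(*args, **kwargs):
            rslt = func(*args, **kwargs)
            sync_device_fn()  # block until pending GPU work has finished
            return rslt

    except AttributeError:
        # Fallback path for older TF: force a sync by reading a tiny
        # tensor back to the host (assumed mechanism, not shown in the diff).
        print(
            "[WARNING] Using deprecated API to resync GPUs. "
            "Non negligeable overhead might be present on function: "
            f"{func_name}."
        )
        p = tf.constant(0.)  # Create small tensor to force GPU resync

        def wrapper(*args, **kwargs):
            rslt = func(*args, **kwargs)
            (p + 1.).numpy()  # host readback flushes queued device work
            return rslt

    return wrapper

Capturing func_name once at decoration time is what lets both print statements, and both wrappers, report the suffixed name assigned by resync_gpu_wrap_fn in benchmark_autotuner.py.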
