1515#include " tensorflow/core/platform/init_main.h"
1616#include " tensorflow/core/platform/logging.h"
1717#include " tensorflow/core/platform/types.h"
18+ #include " tensorflow/core/profiler/lib/profiler_session.h"
19+ #include " tensorflow/core/profiler/lib/traceme.h"
20+ #include " tensorflow/core/profiler/rpc/client/capture_profile.h"
1821#include " tensorflow/core/public/session.h"
1922#include " tensorflow/core/util/command_line_flags.h"
2023
@@ -142,6 +145,24 @@ Status SetupCallable(std::unique_ptr<tensorflow::Session>& session,
142145 return session->MakeCallable (opts, handle);
143146}
144147
148+ // Start the profiling session.
149+ Status StartProfiling (std::unique_ptr<tensorflow::ProfilerSession>& profiler) {
150+ profiler = tensorflow::ProfilerSession::Create (
151+ tensorflow::ProfilerSession::DefaultOptions ()
152+ );
153+ return profiler->Status ();
154+ }
155+
156+ // Tear down the profiler and export tensorboard logs.
157+ Status StopProfiling (std::unique_ptr<tensorflow::ProfilerSession>& profiler,
158+ const string& out_dir) {
159+ tensorflow::profiler::XSpace xspace;
160+ TF_RETURN_IF_ERROR (profiler->CollectData (&xspace));
161+ tensorflow::profiler::ExportToTensorBoard (xspace, out_dir);
162+ profiler.reset ();
163+ return Status::OK ();
164+ }
165+
145166int main (int argc, char * argv[]) {
146167 // Parse arguments
147168 string model_path = " /path/to/model/" ;
@@ -151,6 +172,7 @@ int main(int argc, char* argv[]) {
151172 int32_t eval_iters = 800 ;
152173 bool input_from_device = true ;
153174 bool output_to_host = true ;
175+ string out_dir = " " ;
154176 std::vector<Flag> flag_list = {
155177 Flag (" model_path" , &model_path, " graph to be executed" ),
156178 Flag (" signature_key" , &signature_key, " the serving signature to use" ),
@@ -159,6 +181,7 @@ int main(int argc, char* argv[]) {
159181 Flag (" eval_iters" , &eval_iters, " number of timed iterations to run" ),
160182 Flag (" input_from_device" , &input_from_device, " use inputs from device, rather than host" ),
161183 Flag (" output_to_host" , &output_to_host, " copy outputs to host after inference" ),
184+ Flag (" out_dir" , &out_dir, " if set, runs the profiler and exports to this directory" ),
162185 };
163186 string usage = tensorflow::Flags::Usage (argv[0 ], flag_list);
164187 const bool parse_result = tensorflow::Flags::Parse (&argc, argv, flag_list);
@@ -205,18 +228,29 @@ int main(int argc, char* argv[]) {
205228 std::chrono::steady_clock::time_point eval_start_time;
206229 std::chrono::steady_clock::time_point start_time;
207230 std::chrono::steady_clock::time_point end_time;
231+ std::unique_ptr<tensorflow::ProfilerSession> profiler;
208232 for (int i = 0 ; i < warmup_iters + eval_iters; i++) {
209233 if (i == warmup_iters) {
210234 LOG (INFO) << " Warmup done" ;
235+ if (!out_dir.empty ()) {
236+ StartProfiling (profiler);
237+ }
211238 eval_start_time = std::chrono::steady_clock::now ();
212239 }
213240
214- start_time = std::chrono::steady_clock::now ();
215- TFTRT_ENSURE_OK (
216- bundle.session ->RunCallable (handle, inputs_device, &outputs, nullptr ));
217- // Sync, as `set_fetch_skip_sync(false)` is currently not implemented
218- TFTRT_ENSURE_OK (device->Sync ());
219- end_time = std::chrono::steady_clock::now ();
241+ {
242+ tensorflow::profiler::TraceMe trace ([&i, &warmup_iters]() {
243+ return tensorflow::profiler::TraceMeEncode (
244+ " gpu_compute" , {{" iter" , i - warmup_iters}}
245+ );
246+ }, 1 );
247+ start_time = std::chrono::steady_clock::now ();
248+ TFTRT_ENSURE_OK (
249+ bundle.session ->RunCallable (handle, inputs_device, &outputs, nullptr ));
250+ // Sync, as `set_fetch_skip_sync(false)` is currently not implemented
251+ TFTRT_ENSURE_OK (device->Sync ());
252+ end_time = std::chrono::steady_clock::now ();
253+ }
220254
221255 if ((i % 10 ) == 0 ) {
222256 LOG (INFO) << " step: " << i;
@@ -225,6 +259,9 @@ int main(int argc, char* argv[]) {
225259 double duration = (end_time - start_time).count () / 1e6 ;
226260 infer_time.push_back (duration);
227261 }
262+ if (!out_dir.empty ()) {
263+ StopProfiling (profiler, out_dir);
264+ }
228265 TFTRT_ENSURE_OK (bundle.session ->ReleaseCallable (handle));
229266
230267 // Print results
0 commit comments