@@ -278,14 +278,14 @@ def dynamo_timed(
278278 fail_type : Optional [str ] = None
279279 fail_reason : Optional [str ] = None
280280 time_spent = float ("-inf" )
281- start = time .time_ns ()
281+ start_ns = time .time_ns ()
282282 try :
283283 with torch .profiler .record_function (f"{ key } (dynamo_timed)" ):
284284 t0 = time .time ()
285285 if phase_name :
286- chromium_log .log_event_start (phase_name , start , {"fn_name" : key })
286+ chromium_log .log_event_start (phase_name , start_ns , {"fn_name" : key })
287287 else :
288- chromium_log .log_event_start (key , start , {})
288+ chromium_log .log_event_start (key , start_ns , {})
289289 yield
290290 time_spent = time .time () - t0
291291 compilation_time_metrics [key ].append (time_spent )
@@ -294,16 +294,17 @@ def dynamo_timed(
294294 fail_reason = str (e )
295295 raise
296296 finally :
297+ end_ns = time .time_ns ()
297298 # Always log the end event even on exception
298299 if phase_name :
299300 chromium_log .log_event_end (
300301 phase_name ,
301- time . time_ns () ,
302+ end_ns ,
302303 {},
303- start ,
304+ start_ns ,
304305 )
305306 else :
306- chromium_log .log_event_end (key , time . time_ns () , {}, start )
307+ chromium_log .log_event_end (key , end_ns , {}, start_ns )
307308 # Only record backward compilation metrics if phase_name is not None!
308309 if phase_name :
309310 frame_key = str (curr_frame )
@@ -358,17 +359,41 @@ def dynamo_timed(
358359 structured_logging_overhead_s = (
359360 torch ._logging .get_structured_logging_overhead ()
360361 )
361- metrics = BwdCompilationMetrics (
362- compile_id ,
363- inductor_compile_time ,
364- code_gen_time ,
365- fail_type ,
366- fail_reason ,
367- remote_cache_time_saved ,
368- structured_logging_overhead_s ,
369- False , # is_forward
370- to_int_ms (remote_fx_graph_cache_get_time ),
371- to_int_ms (remote_fx_graph_cache_put_time ),
362+ metrics = CompilationMetrics (
363+ compile_id = compile_id ,
364+ inductor_compile_time_s = inductor_compile_time ,
365+ code_gen_time_s = code_gen_time ,
366+ fail_type = fail_type ,
367+ fail_reason = fail_reason ,
368+ remote_cache_time_saved_s = remote_cache_time_saved ,
369+ structured_logging_overhead_s = structured_logging_overhead_s ,
370+ is_forward = False , # is_forward
371+ remote_fx_graph_cache_get_time_ms = to_int_ms (
372+ remote_fx_graph_cache_get_time
373+ ),
374+ remote_fx_graph_cache_put_time_ms = to_int_ms (
375+ remote_fx_graph_cache_put_time
376+ ),
377+ start_time_us = start_ns // 1000 ,
378+ duration_us = (end_ns - start_ns ) // 1000 ,
379+ inductor_cumulative_compile_time_us = to_int_us (
380+ inductor_compile_time
381+ ),
382+ inductor_code_gen_cumulative_compile_time_us = to_int_us (
383+ code_gen_time
384+ ),
385+ distributed_ephemeral_timeout_us = to_int_us (
386+ remote_cache_time_saved
387+ ), # TODO: instrument more accurately
388+ structured_logging_overhead_us = to_int_us (
389+ structured_logging_overhead_s
390+ ),
391+ remote_fx_graph_cache_get_time_us = to_int_us (
392+ remote_fx_graph_cache_get_time
393+ ),
394+ remote_fx_graph_cache_put_time_us = to_int_us (
395+ remote_fx_graph_cache_put_time
396+ ),
372397 )
373398 record_compilation_metrics (metrics )
374399
@@ -779,69 +804,76 @@ def to_int_ms(v: Optional[float]) -> Optional[int]:
779804 return None if v is None else int (v * 1000 )
780805
781806
807+ # Convert a float number of seconds to integer microseconds. A float64 timestamp has about a quarter microsecond of precision in 2024, so while
808+ # this is suboptimal we shouldn't meaningfully lose precision
809+ def to_int_us (v : Optional [float ]) -> Optional [int ]:
810+ return None if v is None else int (v * 1_000_000 )  # None passes through unchanged; int() truncates toward zero
811+
812+
782813@dataclasses .dataclass
783814class CompilationMetrics :  # after this change every field is Optional with a None default, so callers may pass any subset as keyword arguments
784- compile_id : str
785- frame_key : str
786- co_name : str
787- co_filename : str
788- co_firstlineno : int
789- cache_size : int
790- accumulated_cache_size : int
791- guard_count : Optional [int ]
792- shape_env_guard_count : Optional [int ]
793- graph_op_count : Optional [int ]
794- graph_node_count : Optional [int ]
795- graph_input_count : Optional [int ]
796- start_time : float
797- entire_frame_compile_time_s : Optional [float ]
798- backend_compile_time_s : Optional [float ]
799- inductor_compile_time_s : Optional [float ]
800- code_gen_time_s : Optional [float ]
801- fail_type : Optional [str ]
802- fail_reason : Optional [str ]
803- fail_user_frame_filename : Optional [str ]
804- fail_user_frame_lineno : Optional [int ]
805- non_compliant_ops : Set [str ]
806- compliant_custom_ops : Set [str ]
807- restart_reasons : Set [str ]
808- dynamo_time_before_restart_s : float
815+ compile_id : Optional [ str ] = None
816+ frame_key : Optional [ str ] = None
817+ co_name : Optional [ str ] = None
818+ co_filename : Optional [ str ] = None
819+ co_firstlineno : Optional [ int ] = None
820+ cache_size : Optional [ int ] = None
821+ accumulated_cache_size : Optional [ int ] = None
822+ guard_count : Optional [int ] = None
823+ shape_env_guard_count : Optional [int ] = None
824+ graph_op_count : Optional [int ] = None
825+ graph_node_count : Optional [int ] = None
826+ graph_input_count : Optional [int ] = None
827+ start_time : Optional [ float ] = None
828+ entire_frame_compile_time_s : Optional [float ] = None
829+ backend_compile_time_s : Optional [float ] = None
830+ inductor_compile_time_s : Optional [float ] = None
831+ code_gen_time_s : Optional [float ] = None
832+ fail_type : Optional [str ] = None
833+ fail_reason : Optional [str ] = None
834+ fail_user_frame_filename : Optional [str ] = None
835+ fail_user_frame_lineno : Optional [int ] = None
836+ non_compliant_ops : Optional [ Set [str ]] = None
837+ compliant_custom_ops : Optional [ Set [str ]] = None
838+ restart_reasons : Optional [ Set [str ]] = None
839+ dynamo_time_before_restart_s : Optional [ float ] = None
809840 # Sometimes, we will finish analyzing a frame but conclude we don't want
810841 # to install any guarded code. True means we actually decided to install
811842 # a compiled frame
812- has_guarded_code : bool
813- possibly_missed_reinplacing_opportunities : Optional [int ]
814- remote_cache_time_saved_s : Optional [float ]
815- structured_logging_overhead_s : Optional [float ]
816- config_suppress_errors : Optional [bool ]
817- config_inline_inbuilt_nn_modules : Optional [bool ]
818- specialize_float : Optional [bool ]
819- dynamo_config : Optional [str ]
820- is_forward : Optional [bool ]
821- remote_fx_graph_cache_get_time_ms : Optional [int ]
822- remote_fx_graph_cache_put_time_ms : Optional [int ]
823-
824-
825- @dataclasses .dataclass
826- class BwdCompilationMetrics :  # removed by this change: dynamo_timed now builds a CompilationMetrics with is_forward = False instead
827- compile_id : str
828- inductor_compile_time_s : Optional [float ]
829- code_gen_time_s : Optional [float ]
830- fail_type : Optional [str ]
831- fail_reason : Optional [str ]
832- remote_cache_time_saved_s : Optional [float ]
833- structured_logging_overhead_s : Optional [float ]
834- is_forward : Optional [bool ]
835- remote_fx_graph_cache_get_time_ms : Optional [int ]
836- remote_fx_graph_cache_put_time_ms : Optional [int ]
843+ has_guarded_code : Optional [bool ] = None
844+ possibly_missed_reinplacing_opportunities : Optional [int ] = None
845+ remote_cache_time_saved_s : Optional [float ] = None
846+ structured_logging_overhead_s : Optional [float ] = None
847+ config_suppress_errors : Optional [bool ] = None
848+ config_inline_inbuilt_nn_modules : Optional [bool ] = None
849+ specialize_float : Optional [bool ] = None
850+ dynamo_config : Optional [str ] = None
851+ is_forward : Optional [bool ] = None  # record_compilation_metrics branches on the truthiness of this field, so None is handled like False there
852+ remote_fx_graph_cache_get_time_ms : Optional [int ] = None
853+ remote_fx_graph_cache_put_time_ms : Optional [int ] = None
854+ start_time_us : Optional [int ] = None  # dynamo_timed sets this to time.time_ns() // 1000 taken at entry
855+ duration_us : Optional [int ] = None  # dynamo_timed sets this to (end_ns - start_ns) // 1000
856+ dynamo_cumulative_compile_time_us : Optional [int ] = None
857+ aot_autograd_cumulative_compile_time_us : Optional [int ] = None
858+ inductor_cumulative_compile_time_us : Optional [int ] = None  # dynamo_timed derives this from the same value as inductor_compile_time_s, via to_int_us
859+ inductor_code_gen_cumulative_compile_time_us : Optional [int ] = None  # dynamo_timed derives this from the same value as code_gen_time_s, via to_int_us
860+ triton_compile_time_us : Optional [int ] = None
861+ runtime_cudagraphify_time_us : Optional [int ] = None
862+ runtime_triton_autotune_time_us : Optional [int ] = None
863+ dynamo_compile_time_before_restart_us : Optional [int ] = None
864+ cuda_synchronize_time_us : Optional [int ] = None
865+ distributed_ephemeral_timeout_us : Optional [int ] = None  # dynamo_timed currently fills this from remote_cache_time_saved (marked TODO at that call site)
866+ structured_logging_overhead_us : Optional [int ] = None
867+ remote_fx_graph_cache_get_time_us : Optional [int ] = None
868+ remote_fx_graph_cache_put_time_us : Optional [int ] = None
837869
838870
839871DEFAULT_COMPILATION_METRICS_LIMIT = 64  # used as maxlen for the _compilation_metrics deque
840872
841873
842- _compilation_metrics : Deque [
843- Union [ CompilationMetrics , BwdCompilationMetrics ]
844- ] = collections . deque ( maxlen = DEFAULT_COMPILATION_METRICS_LIMIT )
874+ _compilation_metrics : Deque [CompilationMetrics ] = collections . deque (
875+ maxlen = DEFAULT_COMPILATION_METRICS_LIMIT
876+ )  # bounded deque: once maxlen entries are held, each append discards the oldest entry
845877
846878
847879def add_compilation_metrics_to_chromium (c : CompilationMetrics ):
@@ -866,21 +898,25 @@ def add_compilation_metrics_to_chromium(c: CompilationMetrics):
866898 fail_user_frame_filename = c .fail_user_frame_filename ,
867899 fail_user_frame_lineno = c .fail_user_frame_lineno ,
868900 # Sets aren't JSON serializable
869- non_compliant_ops = list (c .non_compliant_ops ),
870- compliant_custom_ops = list (c .compliant_custom_ops ),
871- restart_reasons = list (c .restart_reasons ),
901+ non_compliant_ops = list (c .non_compliant_ops )
902+ if c .non_compliant_ops is not None
903+ else None ,
904+ compliant_custom_ops = list (c .compliant_custom_ops )
905+ if c .compliant_custom_ops is not None
906+ else None ,
907+ restart_reasons = list (c .restart_reasons )
908+ if c .restart_reasons is not None
909+ else None ,
872910 dynamo_time_before_restart_s = c .dynamo_time_before_restart_s ,
873911 has_guarded_code = c .has_guarded_code ,
874912 dynamo_config = c .dynamo_config ,
875913 )
876914
877915
878- def record_compilation_metrics (
879- compilation_metrics : Union [CompilationMetrics , BwdCompilationMetrics ]
880- ):
916+ def record_compilation_metrics (compilation_metrics : CompilationMetrics ):
881917 global _compilation_metrics
882918 _compilation_metrics .append (compilation_metrics )
883- if isinstance ( compilation_metrics , CompilationMetrics ) :
919+ if compilation_metrics . is_forward :
884920 name = "compilation_metrics"
885921 add_compilation_metrics_to_chromium (compilation_metrics )
886922 else :
@@ -914,7 +950,7 @@ def clear_compilation_metrics() -> None:
914950 _compilation_metrics .clear ()
915951
916952
917- def get_compilation_metrics () -> List [Union [ CompilationMetrics , BwdCompilationMetrics ] ]:
953+ def get_compilation_metrics () -> List [CompilationMetrics ]:
918954 return list (_compilation_metrics )  # shallow copy: the returned list is independent of the deque, but the metrics objects are shared
919955
920956
@@ -957,7 +993,8 @@ def add_event_data(
957993 """
958994 if event_name not in self .get_stack ():
959995 raise RuntimeError (
960- "Cannot add metadata to events that aren't in progress."
996+ f"Event { repr (event_name )} not in { self .get_stack ()} . "
997+ "Cannot add metadata to events that aren't in progress. "
961998 "Please make sure the event has started and hasn't ended."
962999 )
9631000 event_data = self .get_event_data ()
0 commit comments