Skip to content

Commit bca696a

Browse files
ezyang and pytorchmergebot
authored and committed
Switch times to us in CompilationMetrics and improvements (pytorch#138975)
Companion logger diff: https://www.internalfb.com/diff/D65012523

* Using float seconds for timestamps is bad because our internal system defaults to float32 precision, and you don't even get second precision for timestamps in float32.
* We decided to use microseconds instead of milliseconds because, with millisecond granularity, you can end up with the same timestamp if compilation is happening very quickly; it is much better to force non-overlapping spans.
* Because there are so many new fields and I don't feel like reimplementing each on BwdCompilationMetrics, BwdCompilationMetrics is no more; instead, everything in CompilationMetrics is now optional.
* The actual frame compile times collection is not modified (still float) to reduce blast radius, so I just convert to microseconds before making the record. At float64 precision (Python's default), you get about microsecond precision on timestamps, so this shouldn't be a data problem (https://www.leebutterman.com/2021/02/01/store-your-unix-epoch-times-as-float64.html).
* I renamed some entries for clarity. In particular, whenever a timing contains all of its lower phases (e.g., how Inductor also contains Triton compilation), we put "cumulative" in its name. If something doesn't happen at compile time but is delayed until we have actual real inputs, we put "runtime" in its name.

Test plan:

```
buck2 run @mode/opt @mode/inplace //scripts/oulgen:runner
```

And then inspect https://fburl.com/scuba/dynamo_compile/sandbox/mslu7f5w and verify the us columns are populated and meaningful.

Signed-off-by: Edward Z. Yang <[email protected]>
Pull Request resolved: pytorch#138975
Approved by: https://github.com/masnesral
1 parent 9b2c99d commit bca696a

File tree

3 files changed

+148
-83
lines changed

3 files changed

+148
-83
lines changed

test/dynamo/test_structured_trace.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -110,6 +110,8 @@ def format(self, record):
110110
metadata["stack"] = "STACK"
111111
if "compilation_metrics" in metadata:
112112
metadata["compilation_metrics"] = "METRICS"
113+
if "bwd_compilation_metrics" in metadata:
114+
metadata["bwd_compilation_metrics"] = "METRICS"
113115
if "describe_storage" in metadata:
114116
metadata["describe_storage"]["describer_id"] = "ID"
115117
if "describe_tensor" in metadata:
@@ -368,7 +370,7 @@ def test_example_training_fn(self):
368370
{"inductor_post_grad_graph": {}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
369371
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
370372
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
371-
{"bwd_compilation_metrics": {"compile_id": "2/0", "inductor_compile_time_s": <dynamic>, "code_gen_time_s": <dynamic>, "fail_type": null, "fail_reason": null, "remote_cache_time_saved_s": null, "structured_logging_overhead_s": <dynamic>, "is_forward": false, "remote_fx_graph_cache_get_time_ms": null, "remote_fx_graph_cache_put_time_ms": null}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1}
373+
{"bwd_compilation_metrics": "METRICS", "frame_id": 2, "frame_compile_id": 0, "attempt": 1}
372374
{"dynamo_start": {"stack": "STACK"}, "frame_id": 3, "frame_compile_id": 0, "attempt": 0}
373375
{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 3, "frame_compile_id": 0, "attempt": 0}
374376
{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 3, "frame_compile_id": 0, "attempt": 0}

torch/_dynamo/convert_frame.py

Lines changed: 29 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -120,6 +120,7 @@
120120
reset_graph_break_dup_checker,
121121
setup_compile_debug,
122122
to_int_ms,
123+
to_int_us,
123124
troubleshooting_url,
124125
write_record_to_file,
125126
)
@@ -964,7 +965,7 @@ def format_guard_failures() -> str:
964965
]
965966
},
966967
)
967-
start_time = time.time()
968+
start_time_ns = time.time_ns()
968969
fail_type: Optional[str] = None
969970
fail_reason: Optional[str] = None
970971
fail_user_frame_filename: Optional[str] = None
@@ -1021,6 +1022,8 @@ def format_guard_failures() -> str:
10211022
if tracer:
10221023
tracer.output.local_scope = {}
10231024

1025+
duration_ns = time.time_ns() - start_time_ns
1026+
10241027
from .utils import curr_frame
10251028

10261029
frame_key = str(curr_frame)
@@ -1089,7 +1092,7 @@ def format_guard_failures() -> str:
10891092
compliant_custom_ops = set({})
10901093
restart_reasons = set()
10911094
# If compilation failed, the entire time is wasted
1092-
dynamo_time_before_restart = time.time() - start_time
1095+
dynamo_time_before_restart = duration_ns / 1e9
10931096
possibly_missed_reinplacing_opportunities = None
10941097
remote_cache_time_saved = None
10951098
remote_fx_graph_cache_get_time = None
@@ -1124,7 +1127,7 @@ def handle_sets(d: Dict[str, Any]) -> Dict[str, Any]:
11241127
graph_op_count,
11251128
graph_node_count,
11261129
graph_input_count,
1127-
start_time,
1130+
start_time_ns / 1e9,
11281131
entire_frame_compile_time,
11291132
backend_compile_time,
11301133
inductor_compile_time,
@@ -1148,6 +1151,29 @@ def handle_sets(d: Dict[str, Any]) -> Dict[str, Any]:
11481151
True, # is_forward
11491152
to_int_ms(remote_fx_graph_cache_get_time),
11501153
to_int_ms(remote_fx_graph_cache_put_time),
1154+
start_time_us=start_time_ns // 1000,
1155+
duration_us=duration_ns // 1000,
1156+
dynamo_cumulative_compile_time_us=to_int_us(entire_frame_compile_time),
1157+
aot_autograd_cumulative_compile_time_us=to_int_us(backend_compile_time),
1158+
inductor_cumulative_compile_time_us=to_int_us(inductor_compile_time),
1159+
inductor_code_gen_cumulative_compile_time_us=to_int_us(code_gen_time),
1160+
triton_compile_time_us=None, # TODO: instrument
1161+
runtime_cudagraphify_time_us=None, # TODO: instrument in separate event
1162+
runtime_triton_autotune_time_us=None, # TODO: instrument in separate event
1163+
dynamo_compile_time_before_restart_us=to_int_us(
1164+
dynamo_time_before_restart
1165+
),
1166+
cuda_synchronize_time_us=None, # TODO: instrument
1167+
distributed_ephemeral_timeout_us=to_int_us(
1168+
remote_cache_time_saved
1169+
), # TODO: instrument more accurately
1170+
structured_logging_overhead_us=to_int_us(structured_logging_overhead_s),
1171+
remote_fx_graph_cache_get_time_us=to_int_us(
1172+
remote_fx_graph_cache_get_time
1173+
),
1174+
remote_fx_graph_cache_put_time_us=to_int_us(
1175+
remote_fx_graph_cache_put_time
1176+
),
11511177
)
11521178
record_compilation_metrics(metrics)
11531179
torch._dynamo.callback_handler.run_end_callbacks()

torch/_dynamo/utils.py

Lines changed: 116 additions & 79 deletions
Original file line number | Diff line number | Diff line change
@@ -278,14 +278,14 @@ def dynamo_timed(
278278
fail_type: Optional[str] = None
279279
fail_reason: Optional[str] = None
280280
time_spent = float("-inf")
281-
start = time.time_ns()
281+
start_ns = time.time_ns()
282282
try:
283283
with torch.profiler.record_function(f"{key} (dynamo_timed)"):
284284
t0 = time.time()
285285
if phase_name:
286-
chromium_log.log_event_start(phase_name, start, {"fn_name": key})
286+
chromium_log.log_event_start(phase_name, start_ns, {"fn_name": key})
287287
else:
288-
chromium_log.log_event_start(key, start, {})
288+
chromium_log.log_event_start(key, start_ns, {})
289289
yield
290290
time_spent = time.time() - t0
291291
compilation_time_metrics[key].append(time_spent)
@@ -294,16 +294,17 @@ def dynamo_timed(
294294
fail_reason = str(e)
295295
raise
296296
finally:
297+
end_ns = time.time_ns()
297298
# Always log the end event even on exception
298299
if phase_name:
299300
chromium_log.log_event_end(
300301
phase_name,
301-
time.time_ns(),
302+
end_ns,
302303
{},
303-
start,
304+
start_ns,
304305
)
305306
else:
306-
chromium_log.log_event_end(key, time.time_ns(), {}, start)
307+
chromium_log.log_event_end(key, end_ns, {}, start_ns)
307308
# Only record backward compilation metrics if phase_name is not None!
308309
if phase_name:
309310
frame_key = str(curr_frame)
@@ -358,17 +359,41 @@ def dynamo_timed(
358359
structured_logging_overhead_s = (
359360
torch._logging.get_structured_logging_overhead()
360361
)
361-
metrics = BwdCompilationMetrics(
362-
compile_id,
363-
inductor_compile_time,
364-
code_gen_time,
365-
fail_type,
366-
fail_reason,
367-
remote_cache_time_saved,
368-
structured_logging_overhead_s,
369-
False, # is_forward
370-
to_int_ms(remote_fx_graph_cache_get_time),
371-
to_int_ms(remote_fx_graph_cache_put_time),
362+
metrics = CompilationMetrics(
363+
compile_id=compile_id,
364+
inductor_compile_time_s=inductor_compile_time,
365+
code_gen_time_s=code_gen_time,
366+
fail_type=fail_type,
367+
fail_reason=fail_reason,
368+
remote_cache_time_saved_s=remote_cache_time_saved,
369+
structured_logging_overhead_s=structured_logging_overhead_s,
370+
is_forward=False, # is_forward
371+
remote_fx_graph_cache_get_time_ms=to_int_ms(
372+
remote_fx_graph_cache_get_time
373+
),
374+
remote_fx_graph_cache_put_time_ms=to_int_ms(
375+
remote_fx_graph_cache_put_time
376+
),
377+
start_time_us=start_ns // 1000,
378+
duration_us=(end_ns - start_ns) // 1000,
379+
inductor_cumulative_compile_time_us=to_int_us(
380+
inductor_compile_time
381+
),
382+
inductor_code_gen_cumulative_compile_time_us=to_int_us(
383+
code_gen_time
384+
),
385+
distributed_ephemeral_timeout_us=to_int_us(
386+
remote_cache_time_saved
387+
), # TODO: instrument more accurately
388+
structured_logging_overhead_us=to_int_us(
389+
structured_logging_overhead_s
390+
),
391+
remote_fx_graph_cache_get_time_us=to_int_us(
392+
remote_fx_graph_cache_get_time
393+
),
394+
remote_fx_graph_cache_put_time_us=to_int_us(
395+
remote_fx_graph_cache_put_time
396+
),
372397
)
373398
record_compilation_metrics(metrics)
374399

@@ -779,69 +804,76 @@ def to_int_ms(v: Optional[float]) -> Optional[int]:
779804
return None if v is None else int(v * 1000)
780805

781806

807+
# float64 timestamp has a quarter microsecond precision in 2024, so while
808+
# this is suboptimal we shouldn't meaningfully lose precision
809+
def to_int_us(v: Optional[float]) -> Optional[int]:
810+
return None if v is None else int(v * 1_000_000)
811+
812+
782813
@dataclasses.dataclass
783814
class CompilationMetrics:
784-
compile_id: str
785-
frame_key: str
786-
co_name: str
787-
co_filename: str
788-
co_firstlineno: int
789-
cache_size: int
790-
accumulated_cache_size: int
791-
guard_count: Optional[int]
792-
shape_env_guard_count: Optional[int]
793-
graph_op_count: Optional[int]
794-
graph_node_count: Optional[int]
795-
graph_input_count: Optional[int]
796-
start_time: float
797-
entire_frame_compile_time_s: Optional[float]
798-
backend_compile_time_s: Optional[float]
799-
inductor_compile_time_s: Optional[float]
800-
code_gen_time_s: Optional[float]
801-
fail_type: Optional[str]
802-
fail_reason: Optional[str]
803-
fail_user_frame_filename: Optional[str]
804-
fail_user_frame_lineno: Optional[int]
805-
non_compliant_ops: Set[str]
806-
compliant_custom_ops: Set[str]
807-
restart_reasons: Set[str]
808-
dynamo_time_before_restart_s: float
815+
compile_id: Optional[str] = None
816+
frame_key: Optional[str] = None
817+
co_name: Optional[str] = None
818+
co_filename: Optional[str] = None
819+
co_firstlineno: Optional[int] = None
820+
cache_size: Optional[int] = None
821+
accumulated_cache_size: Optional[int] = None
822+
guard_count: Optional[int] = None
823+
shape_env_guard_count: Optional[int] = None
824+
graph_op_count: Optional[int] = None
825+
graph_node_count: Optional[int] = None
826+
graph_input_count: Optional[int] = None
827+
start_time: Optional[float] = None
828+
entire_frame_compile_time_s: Optional[float] = None
829+
backend_compile_time_s: Optional[float] = None
830+
inductor_compile_time_s: Optional[float] = None
831+
code_gen_time_s: Optional[float] = None
832+
fail_type: Optional[str] = None
833+
fail_reason: Optional[str] = None
834+
fail_user_frame_filename: Optional[str] = None
835+
fail_user_frame_lineno: Optional[int] = None
836+
non_compliant_ops: Optional[Set[str]] = None
837+
compliant_custom_ops: Optional[Set[str]] = None
838+
restart_reasons: Optional[Set[str]] = None
839+
dynamo_time_before_restart_s: Optional[float] = None
809840
# Sometimes, we will finish analyzing a frame but conclude we don't want
810841
# to install any guarded code. True means we actually decided to install
811842
# a compiled frame
812-
has_guarded_code: bool
813-
possibly_missed_reinplacing_opportunities: Optional[int]
814-
remote_cache_time_saved_s: Optional[float]
815-
structured_logging_overhead_s: Optional[float]
816-
config_suppress_errors: Optional[bool]
817-
config_inline_inbuilt_nn_modules: Optional[bool]
818-
specialize_float: Optional[bool]
819-
dynamo_config: Optional[str]
820-
is_forward: Optional[bool]
821-
remote_fx_graph_cache_get_time_ms: Optional[int]
822-
remote_fx_graph_cache_put_time_ms: Optional[int]
823-
824-
825-
@dataclasses.dataclass
826-
class BwdCompilationMetrics:
827-
compile_id: str
828-
inductor_compile_time_s: Optional[float]
829-
code_gen_time_s: Optional[float]
830-
fail_type: Optional[str]
831-
fail_reason: Optional[str]
832-
remote_cache_time_saved_s: Optional[float]
833-
structured_logging_overhead_s: Optional[float]
834-
is_forward: Optional[bool]
835-
remote_fx_graph_cache_get_time_ms: Optional[int]
836-
remote_fx_graph_cache_put_time_ms: Optional[int]
843+
has_guarded_code: Optional[bool] = None
844+
possibly_missed_reinplacing_opportunities: Optional[int] = None
845+
remote_cache_time_saved_s: Optional[float] = None
846+
structured_logging_overhead_s: Optional[float] = None
847+
config_suppress_errors: Optional[bool] = None
848+
config_inline_inbuilt_nn_modules: Optional[bool] = None
849+
specialize_float: Optional[bool] = None
850+
dynamo_config: Optional[str] = None
851+
is_forward: Optional[bool] = None
852+
remote_fx_graph_cache_get_time_ms: Optional[int] = None
853+
remote_fx_graph_cache_put_time_ms: Optional[int] = None
854+
start_time_us: Optional[int] = None
855+
duration_us: Optional[int] = None
856+
dynamo_cumulative_compile_time_us: Optional[int] = None
857+
aot_autograd_cumulative_compile_time_us: Optional[int] = None
858+
inductor_cumulative_compile_time_us: Optional[int] = None
859+
inductor_code_gen_cumulative_compile_time_us: Optional[int] = None
860+
triton_compile_time_us: Optional[int] = None
861+
runtime_cudagraphify_time_us: Optional[int] = None
862+
runtime_triton_autotune_time_us: Optional[int] = None
863+
dynamo_compile_time_before_restart_us: Optional[int] = None
864+
cuda_synchronize_time_us: Optional[int] = None
865+
distributed_ephemeral_timeout_us: Optional[int] = None
866+
structured_logging_overhead_us: Optional[int] = None
867+
remote_fx_graph_cache_get_time_us: Optional[int] = None
868+
remote_fx_graph_cache_put_time_us: Optional[int] = None
837869

838870

839871
DEFAULT_COMPILATION_METRICS_LIMIT = 64
840872

841873

842-
_compilation_metrics: Deque[
843-
Union[CompilationMetrics, BwdCompilationMetrics]
844-
] = collections.deque(maxlen=DEFAULT_COMPILATION_METRICS_LIMIT)
874+
_compilation_metrics: Deque[CompilationMetrics] = collections.deque(
875+
maxlen=DEFAULT_COMPILATION_METRICS_LIMIT
876+
)
845877

846878

847879
def add_compilation_metrics_to_chromium(c: CompilationMetrics):
@@ -866,21 +898,25 @@ def add_compilation_metrics_to_chromium(c: CompilationMetrics):
866898
fail_user_frame_filename=c.fail_user_frame_filename,
867899
fail_user_frame_lineno=c.fail_user_frame_lineno,
868900
# Sets aren't JSON serializable
869-
non_compliant_ops=list(c.non_compliant_ops),
870-
compliant_custom_ops=list(c.compliant_custom_ops),
871-
restart_reasons=list(c.restart_reasons),
901+
non_compliant_ops=list(c.non_compliant_ops)
902+
if c.non_compliant_ops is not None
903+
else None,
904+
compliant_custom_ops=list(c.compliant_custom_ops)
905+
if c.compliant_custom_ops is not None
906+
else None,
907+
restart_reasons=list(c.restart_reasons)
908+
if c.restart_reasons is not None
909+
else None,
872910
dynamo_time_before_restart_s=c.dynamo_time_before_restart_s,
873911
has_guarded_code=c.has_guarded_code,
874912
dynamo_config=c.dynamo_config,
875913
)
876914

877915

878-
def record_compilation_metrics(
879-
compilation_metrics: Union[CompilationMetrics, BwdCompilationMetrics]
880-
):
916+
def record_compilation_metrics(compilation_metrics: CompilationMetrics):
881917
global _compilation_metrics
882918
_compilation_metrics.append(compilation_metrics)
883-
if isinstance(compilation_metrics, CompilationMetrics):
919+
if compilation_metrics.is_forward:
884920
name = "compilation_metrics"
885921
add_compilation_metrics_to_chromium(compilation_metrics)
886922
else:
@@ -914,7 +950,7 @@ def clear_compilation_metrics() -> None:
914950
_compilation_metrics.clear()
915951

916952

917-
def get_compilation_metrics() -> List[Union[CompilationMetrics, BwdCompilationMetrics]]:
953+
def get_compilation_metrics() -> List[CompilationMetrics]:
918954
return list(_compilation_metrics)
919955

920956

@@ -957,7 +993,8 @@ def add_event_data(
957993
"""
958994
if event_name not in self.get_stack():
959995
raise RuntimeError(
960-
"Cannot add metadata to events that aren't in progress."
996+
f"Event {repr(event_name)} not in {self.get_stack()}. "
997+
"Cannot add metadata to events that aren't in progress. "
961998
"Please make sure the event has started and hasn't ended."
962999
)
9631000
event_data = self.get_event_data()

0 commit comments

Comments
 (0)