Skip to content

Commit 41436fb

Browse files
committed
Adds more precise statistics about fusion
1 parent f86c55d commit 41436fb

File tree

2 files changed

+54
-4
lines changed

2 files changed

+54
-4
lines changed

onnx_diagnostic/_command_lines_parser.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -827,6 +827,8 @@ def get_parser_agg() -> ArgumentParser:
827827
"n_model_running,n_model_acc01,n_model_acc001,n_model_dynamic,"
828828
"n_model_pass,n_model_faster,"
829829
"n_model_faster2x,n_model_faster3x,n_model_faster4x,n_node_attention,"
830+
"n_node_attention23,n_node_rotary_embedding,n_node_rotary_embedding23,"
831+
"n_node_layer_normalization,n_node_layer_normalization23,"
830832
"peak_gpu_torch,peak_gpu_nvidia,n_node_control_flow,"
831833
"n_node_constant,n_node_shape,n_node_expand,"
832834
"n_node_function,n_node_initializer,n_node_scatter,"

onnx_diagnostic/helpers/log_helper.py

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -877,7 +877,11 @@ def view(
877877
print(f"[CubeLogs.view] key_columns={key_columns}")
878878
g = data[[*key_index, *key_columns]].copy()
879879
g["count"] = 1
880-
r = g.groupby([*key_index, *key_columns], dropna=False).sum()
880+
r = (
881+
g.copy()
882+
if not key_index and not key_columns
883+
else g.groupby([*key_index, *key_columns], dropna=False).sum()
884+
)
881885
not_unique = r[r["count"] > 1]
882886
assert not_unique.shape[0] == 0, (
883887
f"view_def.name={view_def.name!r}, "
@@ -1505,6 +1509,11 @@ def __init__(
15051509
"n_model_faster3x",
15061510
"n_model_faster4x",
15071511
"n_node_attention",
1512+
"n_node_attention23",
1513+
"n_node_rotary_embedding",
1514+
"n_node_rotary_embedding23",
1515+
"n_node_layer_normalization",
1516+
"n_node_layer_normalization23",
15081517
"n_node_control_flow",
15091518
"n_node_scatter",
15101519
"n_node_function",
@@ -1676,11 +1685,50 @@ def first_err(df: pandas.DataFrame) -> pandas.Series:
16761685
"time_latency",
16771686
gdf(df, "time_latency_eager") > gdf(df, "time_latency", np.inf) * 3.98,
16781687
),
1688+
n_node_attention23=lambda df: gpreserve(
1689+
df, "op_onnx__Attention", gdf(df, "op_onnx__Attention")
1690+
),
1691+
n_node_rotary_embedding23=lambda df: gpreserve(
1692+
df, "op_onnx__RotaryEmbedding", gdf(df, "op_onnx__RotaryEmbedding")
1693+
),
1694+
n_node_layer_normalization23=lambda df: gpreserve(
1695+
df,
1696+
"time_latency",
1697+
gdf(df, "op_onnx__LayerNormalization", 0)
1698+
+ gdf(df, "op_onnx__RMSNormalization", 0)
1699+
+ gdf(df, "op_onnx__BatchNormlization", 0)
1700+
+ gdf(df, "op_onnx__InstanceNormlization", 0)
1701+
+ gdf(df, "op_onnx__GroupNormalization", 0),
1702+
),
16791703
n_node_attention=lambda df: gpreserve(
16801704
df,
1681-
"op_onnx_com.microsoft_Attention",
1682-
gdf(df, "op_onnx_com.microsoft_Attention")
1683-
+ gdf(df, "op_onnx_com.microsoft_MultiHeadAttention"),
1705+
"time_latency",
1706+
gdf(df, "op_onnx_com.microsoft_Attention", 0)
1707+
+ gdf(df, "op_onnx_com.microsoft_MultiHeadAttention", 0)
1708+
+ gdf(df, "op_onnx_com.microsoft_PackedAttention", 0)
1709+
+ gdf(df, "op_onnx_com.microsoft_PackedMultiHeadAttention", 0)
1710+
+ gdf(df, "op_onnx_com.microsoft_GroupQueryAttention", 0)
1711+
+ gdf(df, "op_onnx_com.microsoft_PagedAttention", 0)
1712+
+ gdf(df, "op_onnx_com.microsoft_DecoderAttention", 0)
1713+
+ gdf(df, "op_onnx_com.microsoft_LongformerAttention", 0)
1714+
+ gdf(df, "op_onnx_com.microsoft_DecoderMaskedSelfAttention", 0)
1715+
+ gdf(df, "op_onnx_com.microsoft_DecoderMaskedMultiHeadAttention", 0)
1716+
+ gdf(df, "op_onnx_com.microsoft_SparseAttention", 0),
1717+
),
1718+
n_node_layer_normalization=lambda df: gpreserve(
1719+
df,
1720+
"time_latency",
1721+
gdf(df, "op_onnx_com.microsoft_EmbedLayerNormalization", 0)
1722+
+ gdf(df, "op_onnx_com.microsoft_SkipLayerNormalization", 0)
1723+
+ gdf(df, "op_onnx_com.microsoft_LayerNormalization", 0)
1724+
+ gdf(df, "op_onnx_com.microsoft_SkipSimplifiedLayerNormalization", 0)
1725+
+ gdf(df, "op_onnx_com.microsoft_SimplifiedLayerNormalization", 0),
1726+
),
1727+
n_node_rotary_embedding=lambda df: gpreserve(
1728+
df,
1729+
"time_latency",
1730+
gdf(df, "op_onnx_com.microsoft_GemmaRotaryEmbedding", 0)
1731+
+ gdf(df, "op_onnx_com.microsoft_RotaryEmbedding", 0),
16841732
),
16851733
n_node_control_flow=lambda df: gpreserve(
16861734
df,

0 commit comments

Comments
 (0)