Skip to content

Commit 04b29c8

Browse files
committed
Update
1 parent f5a62ad commit 04b29c8

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

bench/bench/bench_mlp.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,14 +95,15 @@ def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype,
9595

9696
# -- analyze --
9797
gf, inclusive_metrics, exclusive_metrics, device_info = viewer.read(fpath)
98-
matmuls = gf.dataframe[gf.dataframe["name"].str.contains("matmul")]
98+
# Now the dataframe only contains leave nodes (i.e., kernels) that perform matmuls
99+
matmuls = gf.dataframe[gf.dataframe["name"].str.contains("matmul") & gf.dataframe["device_id"].notna()]
99100
tot_bytes = matmuls["bytes"].sum()
100101
tot_flops = sum(matmuls[[c for c in ['flops8', 'flops16'] if c in matmuls.columns]].sum())
101102
tot_time = matmuls["time (ns)"].sum()
102103

103104
# Calculate theoretical min time based on hardware limits
104-
device_type = matmuls["device_type"].dropna().iloc[0]
105-
device_id = matmuls["device_id"].dropna().iloc[0]
105+
device_type = matmuls["device_type"].iloc[0]
106+
device_id = matmuls["device_id"].iloc[0]
106107
info = device_info[device_type][device_id]
107108

108109
min_time_flops_sec = sum(

0 commit comments

Comments
 (0)