from collections import namedtuple
import json
import pandas as pd

try:
    import hatchet as ht
    from hatchet.query import NegationQuery
except ImportError:
    raise ImportError("Failed to import hatchet. `pip install llnl-hatchet` to get the correct version.")
import numpy as np
from triton.profiler.hook import COMPUTE_METADATA_SCOPE_NAME, TritonHook
from triton.profiler import specs
1214
1315
1416def match_available_metrics (metrics , inclusive_metrics , exclusive_metrics ):
@@ -83,25 +85,7 @@ def get_min_time_flops(df, device_info):
8385 device_frames = df [idx ]
8486 if f"flops{ width } " not in device_frames .columns :
8587 continue
86- max_flops = 0
87- if device_type == "CUDA" :
88- if arch == "80" :
89- max_flops = 624e12 / (width / 8 )
90- elif arch == "89" :
91- # TODO(Keren): Implement fp16 acc-> 660.6 fp8
92- max_flops = (330.3 * 1e12 ) / (width / 8 )
93- elif arch == "90" :
94- # 114 sms and 1755mhz is the base number of sms and clock rate of H100 pcie
95- max_flops = ((num_sms / 114 * clock_rate / (1755 * 1e3 ) * 1513 ) * 1e12 ) / (width / 8 )
96- elif arch == "100" :
97- max_flops = (num_sms * 16384 * (clock_rate / 1e3 ) * 1e6 ) / (width / 8 )
98- elif device_type == "HIP" :
99- if arch == "gfx90a" :
100- max_flops = 383e12 / (width / 8 )
101- elif arch == "gfx941" or arch == "gfx942" :
102- max_flops = 2614.9e12 / (width / 8 )
103- else :
104- raise ValueError (f"Unsupported device type: { device_type } " )
88+ max_flops = specs .max_flops (device_type , arch , width , num_sms , clock_rate )
10589 min_time_flops .loc [idx , "min_time" ] += device_frames [f"flops{ width } " ].fillna (0 ) / max_flops
10690 return min_time_flops
10791
@@ -114,7 +98,7 @@ def get_min_time_bytes(df, device_info):
11498 device_frames = df [idx ]
11599 memory_clock_rate = device_info [device_type ][device_index ]["memory_clock_rate" ] # in khz
116100 bus_width = device_info [device_type ][device_index ]["bus_width" ] # in bits
117- peak_bandwidth = 2 * bus_width * memory_clock_rate * 1e3 / 8
101+ peak_bandwidth = specs . max_bps ( bus_width , memory_clock_rate )
118102 min_time_bytes .loc [idx , "min_time" ] += device_frames ["bytes" ] / peak_bandwidth
119103 return min_time_bytes
120104
@@ -150,7 +134,7 @@ def derive_metrics(gf, metrics, inclusive_metrics, exclusive_metrics, device_inf
150134
151135 def get_time_seconds (df , metric , factor_dict ):
152136 time_metric_name = match_available_metrics (metric , inclusive_metrics , exclusive_metrics )[0 ]
153- time_unit = ( factor_dict .name + "/" + time_metric_name .split ("(" )[1 ].split (")" )[0 ])
137+ time_unit = factor_dict .name + "/" + time_metric_name .split ("(" )[1 ].split (")" )[0 ]
154138 return df [time_metric_name ] * factor_dict .factor [time_unit ]
155139
156140 for metric in metrics :
@@ -171,13 +155,13 @@ def get_time_seconds(df, metric, factor_dict):
171155 (get_time_seconds (gf .dataframe , "time" , time_factor_dict )) /
172156 metric_factor_dict [metric ])
173157 derived_metrics .append (f"{ metric } (inc)" )
174- elif metric in time_factor_dict .factor or metric in cpu_time_factor_dict .factor or \
175- metric in avg_time_factor_dict .factor or metric in avg_cpu_time_factor_dict .factor : # inclusive
158+ elif ( metric in time_factor_dict .factor or metric in cpu_time_factor_dict .factor
159+ or metric in avg_time_factor_dict .factor or metric in avg_cpu_time_factor_dict .factor ) : # inclusive
176160 is_cpu = metric in cpu_time_factor_dict .factor or metric in avg_cpu_time_factor_dict .factor
177161 is_avg = metric in avg_time_factor_dict .factor or metric in avg_cpu_time_factor_dict .factor
178162
179- factor_dict = (avg_cpu_time_factor_dict if is_avg else cpu_time_factor_dict ) if is_cpu \
180- else (avg_time_factor_dict if is_avg else time_factor_dict )
163+ factor_dict = (( avg_cpu_time_factor_dict if is_avg else cpu_time_factor_dict ) if is_cpu else
164+ (avg_time_factor_dict if is_avg else time_factor_dict ) )
181165 metric_name = "cpu_time" if is_cpu else "time"
182166 metric_time_unit = factor_dict .name + "/" + metric .split ("/" )[1 ]
183167
@@ -265,21 +249,26 @@ def print_tree(gf, metrics, depth=100, format=None, print_sorted=False):
265249 print ("Sorted kernels by metric " + metrics [0 ])
266250 sorted_df = gf .dataframe .sort_values (by = [metrics [0 ]], ascending = False )
267251 for row in range (1 , len (sorted_df )):
268- kernel_name = sorted_df .iloc [row ][' name' ][:100 ] + "..." if len (
269- sorted_df .iloc [row ][' name' ]) > 100 else sorted_df .iloc [row ][' name' ]
252+ kernel_name = ( sorted_df .iloc [row ][" name" ][:100 ] +
253+ "..." if len ( sorted_df .iloc [row ][" name" ]) > 100 else sorted_df .iloc [row ][" name" ])
270254 print ("{:105} {:.4}" .format (kernel_name , sorted_df .iloc [row ][metrics [0 ]]))
271255 emit_warnings (gf , metrics )
272256
273257
def read(filename):
    """Load a proton profile file and return its raw metric data.

    Args:
        filename: Path to the profile JSON file produced by proton.

    Returns:
        A tuple ``(gf, inclusive_metrics, exclusive_metrics, device_info)``
        where ``gf`` is a hatchet GraphFrame whose inclusive columns have
        been populated.

    Raises:
        AssertionError: If the file contains no metrics at all.
    """
    with open(filename, "r") as f:
        gf, inclusive_metrics, exclusive_metrics, device_info = get_raw_metrics(f)
        assert len(inclusive_metrics + exclusive_metrics) > 0, "No metrics found in the input file"
        gf.update_inclusive_columns()
    return gf, inclusive_metrics, exclusive_metrics, device_info


def parse(metrics, filename, include=None, exclude=None, threshold=None):
    """Read a profile and derive, then filter, the requested metrics.

    Args:
        metrics: List of metric names requested by the caller.
        filename: Path to the profile file; forwarded to ``read``.
        include: Optional query selecting frames to keep.
        exclude: Optional query selecting frames to drop.
        threshold: Optional minimum metric value for a frame to be kept.

    Returns:
        A tuple ``(gf, metrics)`` of the filtered GraphFrame and the list of
        derived metric column names.
    """
    gf, inclusive_metrics, exclusive_metrics, device_info = read(filename)
    metrics = derive_metrics(gf, metrics, inclusive_metrics, exclusive_metrics, device_info)
    # TODO: generalize to support multiple metrics, not just the first one
    gf = filter_frames(gf, include, exclude, threshold, metrics[0])
    return gf, metrics
283272
284273
285274def show_metrics (file_name ):
@@ -368,23 +357,32 @@ def main():
368357 help = "The depth of the tree to display" ,
369358 )
370359 argparser .add_argument (
371- "-f" , "--format" , type = str , choices = ["full" , "file_function_line" , "function_line" , "file_function" ],
372- default = "full" , help = """Formatting the frame name.
360+ "-f" ,
361+ "--format" ,
362+ type = str ,
363+ choices = ["full" , "file_function_line" , "function_line" , "file_function" ],
364+ default = "full" ,
365+ help = """Formatting the frame name.
373366- full: include the path, file name, function name and line number.
374367- file_function_line: include the file name, function name and line number.
375368- function_line: include the function name and line number.
376369- file_function: include the file name and function name.
377- """ )
370+ """ ,
371+ )
    argparser.add_argument(
        "--print-sorted",
        action="store_true",
        default=False,
        help="Sort output by metric value instead of chronologically",
    )
384378 argparser .add_argument (
385- "--diff-profile" , "-diff" , type = str , default = None ,
379+ "--diff-profile" ,
380+ "-diff" ,
381+ type = str ,
382+ default = None ,
386383 help = "Compare two profiles. When used as 'proton-viewer -m time -diff file1.log file2.log', "
387- "computes the difference: file2['time'] - file1['time']" )
384+ "computes the difference: file2['time'] - file1['time']" ,
385+ )
388386
389387 args , target_args = argparser .parse_known_args ()
390388 assert len (target_args ) == 1 , "Must specify a file to read"
0 commit comments