11import os
2+ import sys
23
34
45os .environ ["TORCH_LOGS" ] = "inductor"
3233UNITS = {
3334 "name" : "" ,
3435 "forward_time" : " (us)" ,
36+ "teraflops" : " (TFLOPS)" ,
3537 "compilation_time" : " (s)" ,
3638}
3739PERF_OVER_ATEN_STR : str = "perf_over_aten (%)"
7577
7678
7779def benchmark_torch_function_in_microseconds (func : Callable , * args , ** kwargs ) -> float :
78- return do_bench (lambda : func (* args , ** kwargs )) * 1e3
80+ return do_bench (lambda : func (* args , ** kwargs ), warmup = 100 , rep = 10000 ) * 1e3
7981
8082
8183@dataclass (frozen = True , kw_only = True )
@@ -162,6 +164,7 @@ def name(self) -> str:
162164class ExperimentResults :
163165 name : str
164166 forward_time : float
167+ teraflops : float
165168 compilation_time : float
166169
167170 def asdict (self ):
@@ -211,7 +214,10 @@ def run_single_experiment_group(
211214 for config in group_config .experiments :
212215 torch ._dynamo .reset ()
213216 torch ._inductor .utils .clear_inductor_caches ()
214- compiled_op = torch .compile (op , fullgraph = True , options = config .to_options ())
217+ compiled_op = torch .compile (
218+ op ,
219+ options = config .to_options (),
220+ )
215221
216222 start_time = time .perf_counter ()
217223 try :
@@ -227,6 +233,7 @@ def run_single_experiment_group(
227233 ExperimentResults (
228234 name = config .name (),
229235 forward_time = float ("inf" ),
236+ teraflops = 0.0 ,
230237 compilation_time = float ("inf" ),
231238 )
232239 )
@@ -238,10 +245,18 @@ def run_single_experiment_group(
238245 * inputs ,
239246 )
240247
248+ flops = calculate_flops (
249+ group_config .op_name ,
250+ group_config .shape ,
251+ group_config .batch_size ,
252+ )
253+ teraflops = flops / (forward_time * 1e-6 ) / 1e12
254+
241255 results .append (
242256 ExperimentResults (
243257 name = config .name (),
244258 forward_time = forward_time ,
259+ teraflops = teraflops ,
245260 compilation_time = compilation_time ,
246261 )
247262 )
@@ -336,6 +351,20 @@ def calculate_table_data(results: list[ExperimentResults]) -> dict:
336351 return table_data
337352
338353
354+ def calculate_flops (op_name : str , shape : tuple [int , int , int ], batch_size : int ) -> int :
355+ """
356+ Calculate the number of floating point operations based on operation type and shape.
357+ """
358+ M , N , K = shape
359+
360+ if op_name == "bmm" :
361+ return 2 * batch_size * M * N * K
362+ elif op_name == "addmm" :
363+ return 2 * M * N * K + M * N
364+ else :
365+ return 2 * M * N * K
366+
367+
339368def get_printable_results (experiment_groups : list [ExperimentGroup ]) -> list [str ]:
340369 edge_over_aten = defaultdict (list )
341370 output = []
@@ -390,8 +419,10 @@ def main():
390419 results .append (
391420 ExperimentGroup (config = group_config , results = group_results ),
392421 )
393- log .info (f"\n INTERMEDIATE results: { i } /{ len (configs )} " ) # noqa: G004
394- log .info (get_printable_results (results ))
422+ sys .stderr .write (
423+ f"\n INTERMEDIATE results: { i + 1 } /{ len (configs )} \n "
424+ + get_printable_results (results )
425+ )
395426 print ("\n FINAL results..." )
396427 print (get_printable_results (results ))
397428
0 commit comments