import os
from collections import defaultdict
from numbers import Number
-from typing import Any, List
+from typing import Any

import torch
from torch.utils._python_dispatch import TorchDispatchMode
@@ -30,7 +30,7 @@ def prod(x):
    return res


-def matmul_flop(inputs: List[Any], outputs: List[Any]) -> Number:
+def matmul_flop(inputs: list[Any], outputs: list[Any]) -> Number:
    """
    Count flops for matmul.
    """
@@ -43,7 +43,7 @@ def matmul_flop(inputs: List[Any], outputs: List[Any]) -> Number:
    return flop


-def addmm_flop(inputs: List[Any], outputs: List[Any]) -> Number:
+def addmm_flop(inputs: list[Any], outputs: list[Any]) -> Number:
    """
    Count flops for fully connected layers.
    """
@@ -60,7 +60,7 @@ def addmm_flop(inputs: List[Any], outputs: List[Any]) -> Number:
    return flops


-def bmm_flop(inputs: List[Any], outputs: List[Any]) -> Number:
+def bmm_flop(inputs: list[Any], outputs: list[Any]) -> Number:
    """
    Count flops for the bmm operation.
    """
@@ -75,9 +75,9 @@ def bmm_flop(inputs: List[Any], outputs: List[Any]) -> Number:


def conv_flop_count(
-    x_shape: List[int],
-    w_shape: List[int],
-    out_shape: List[int],
+    x_shape: list[int],
+    w_shape: list[int],
+    out_shape: list[int],
    transposed: bool = False,
) -> Number:
    """
@@ -99,7 +99,7 @@ def conv_flop_count(
    return flop


-def conv_flop(inputs: List[Any], outputs: List[Any]):
+def conv_flop(inputs: list[Any], outputs: list[Any]):
    """
    Count flops for convolution.
    """
@@ -110,7 +110,7 @@ def conv_flop(inputs: List[Any], outputs: List[Any]):
    return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)


-def quant_conv_flop(inputs: List[Any], outputs: List[Any]):
+def quant_conv_flop(inputs: list[Any], outputs: list[Any]):
    """
    Count flops for quantized convolution.
    """
@@ -124,8 +124,8 @@ def transpose_shape(shape):
    return [shape[1], shape[0]] + list(shape[2:])


-def conv_backward_flop(inputs: List[Any], outputs: List[Any]):
-    grad_out_shape, x_shape, w_shape = [get_shape(i) for i in inputs[:3]]
+def conv_backward_flop(inputs: list[Any], outputs: list[Any]):
+    grad_out_shape, x_shape, w_shape = (get_shape(i) for i in inputs[:3])
    output_mask = inputs[-1]
    fwd_transposed = inputs[7]
    flop_count = 0
@@ -140,7 +140,7 @@ def conv_backward_flop(inputs: List[Any], outputs: List[Any]):
    return flop_count


-def scaled_dot_product_flash_attention_flop(inputs: List[Any], outputs: List[Any]):
+def scaled_dot_product_flash_attention_flop(inputs: list[Any], outputs: list[Any]):
    # FIXME: this needs to count the flops of this kernel
    # https://github.com/pytorch/pytorch/blob/207b06d099def9d9476176a1842e88636c1f714f/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp#L52-L267
    return 0
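
For context while reviewing: the function bodies are elided by the hunks above, so here is a minimal sketch of the fvcore-style convention (one flop per multiply-accumulate) that counters with these names usually implement. The helper names and the example numbers below are illustrative assumptions, not code from this file.

# Illustrative only: assumes the common MAC-counting convention; the bodies
# elided from this diff may use a different convention (e.g. 2x MACs).
from math import prod

def matmul_macs(a_shape: list[int], b_shape: list[int]) -> int:
    # (m, k) @ (k, n) costs m * k * n multiply-accumulates
    m, k = a_shape
    _, n = b_shape
    return m * k * n

def bmm_macs(a_shape: list[int], b_shape: list[int]) -> int:
    # (b, m, k) @ (b, k, n) costs b * m * k * n multiply-accumulates
    b, m, k = a_shape
    _, _, n = b_shape
    return b * m * k * n

def conv_macs(x_shape: list[int], w_shape: list[int], out_shape: list[int]) -> int:
    # batch size * output spatial size * number of elements in the weight tensor
    return x_shape[0] * prod(out_shape[2:]) * prod(w_shape)

# e.g. a 1x3x224x224 input through a 64x3x7x7 stride-2 conv (112x112 output):
# conv_macs([1, 3, 224, 224], [64, 3, 7, 7], [1, 64, 112, 112]) == 118_013_952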
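
Similarly, since the file imports TorchDispatchMode, a hedged sketch of how per-operator counters like the ones above are typically registered and dispatched. The FLOP_COUNTERS registry and FlopCounterMode class below are hypothetical stand-ins for illustration, not this module's actual API; only the TorchDispatchMode pattern itself is real PyTorch.

# Hypothetical wiring -- the registry and mode class names are invented here.
import torch
from torch.utils._python_dispatch import TorchDispatchMode

aten = torch.ops.aten

FLOP_COUNTERS = {
    # mm: (m, k) @ (k, n) -> m * k * n multiply-accumulates
    aten.mm.default: lambda args, out: args[0].shape[0] * args[0].shape[1] * args[1].shape[1],
}

class FlopCounterMode(TorchDispatchMode):
    def __init__(self):
        super().__init__()
        self.total = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # Run the op as usual, then look up a counter for it, if one is registered.
        out = func(*args, **(kwargs or {}))
        counter = FLOP_COUNTERS.get(func)
        if counter is not None:
            self.total += counter(args, out)
        return out

# Usage:
#     with FlopCounterMode() as mode:
#         torch.mm(torch.randn(8, 16), torch.randn(16, 32))
#     print(mode.total)  # 8 * 16 * 32 = 4096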