55from . import OpRun , OpRunTensor
66
77
8+ class AveragePool_11 (OpRun ):
9+ "AveragePool"
10+
11+ def __init__ (self , node : onnx .NodeProto , version : Optional [int ] = None ):
12+ super ().__init__ (node , version )
13+ self .auto_pad = self .get_attribute_string (node , "auto_pad" , "NOTSET" )
14+ self .ceil_mode = bool (self .get_attribute_int (node , "ceil_mode" , 0 ))
15+ self .count_include_pad = bool (self .get_attribute_int (node , "count_include_pad" , 0 ))
16+ self .dilations = self .get_attribute_ints (node , "dilations" , None )
17+ self .kernel_shape : Tuple [int , ...] = (
18+ self .get_attribute_ints (node , "kernel_shape" ) or tuple ()
19+ )
20+ self .pads = self .get_attribute_ints (node , "pads" , None )
21+ self .strides = self .get_attribute_ints (node , "strides" , None )
22+
23+ def run (self , x ):
24+ kernel_shape = self .kernel_shape
25+ dilations = self .dilations or [1 for _ in x .shape [2 :]]
26+ strides = self .strides or [1 for _ in x .shape [2 :]]
27+ pads = self .pads or ([0 for _ in x .shape [2 :]] * 2 )
28+ assert (
29+ self .auto_pad == "NOTSET"
30+ ), f"conv not implemented for auto_pad={ self .auto_pad !r} "
31+ assert len (set (pads )) == 1 , f"conv not implemented for pads={ pads } "
32+ assert set (dilations ) == {1 }, f"conv not implemented for dilations={ dilations } "
33+ avg_pool = getattr (torch .nn .functional , f"avg_pool{ len (kernel_shape )} d" )
34+ return OpRunTensor (
35+ avg_pool (
36+ x .tensor ,
37+ kernel_size = tuple (kernel_shape ),
38+ stride = tuple (strides ),
39+ padding = pads [0 ],
40+ ceil_mode = self .ceil_mode ,
41+ count_include_pad = self .count_include_pad ,
42+ # dilation=tuple(dilations),
43+ )
44+ )
45+
46+
847class Conv_11 (OpRun ):
948 "Conv"
1049
@@ -22,15 +61,15 @@ def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
2261 def run (self , x , w , b = None ):
2362 kernel_shape = self .kernel_shape or w .shape [2 :]
2463 assert (
25- tuple (kernel_shape ) == w .shape [2 :]
64+ tuple (kernel_shape ) == w .shape [- len ( kernel_shape ) :]
2665 ), f"conv not implemented for kernel_shape={ kernel_shape } and w.shape={ w .shape } "
2766 dilations = self .dilations or [1 for _ in x .shape [2 :]]
2867 strides = self .strides or [1 for _ in x .shape [2 :]]
2968 pads = self .pads or ([0 for _ in x .shape [2 :]] * 2 )
3069 assert (
3170 self .auto_pad == "NOTSET"
3271 ), f"conv not implemented for auto_pad={ self .auto_pad !r} "
33- assert len (set (pads )) == 1 , f"conv not implemented for dilations ={ pads } "
72+ assert len (set (pads )) == 1 , f"conv not implemented for pads ={ pads } "
3473 if b is None :
3574 bias = None
3675 else :
0 commit comments