@@ -32,8 +32,12 @@ def make_sdfg(node, parent_state, parent_sdfg):
3232 UserWarning )
3333 elif not res :
3434 raise SyntaxError ("Matrix sizes must match" )
35+
36+ # Determine output shape based on batch options
3537 if bopt :
36- shape_c = (bopt ['b' ], shape_a [- 2 ], shape_b [- 1 ])
38+ # Use batch dimensions from bopt (may be multi-dimensional)
39+ batch_dims = bopt .get ('batch_dims' , [bopt ['b' ]])
40+ shape_c = tuple (batch_dims ) + (shape_a [- 2 ], shape_b [- 1 ])
3741 else :
3842 shape_c = (shape_a [- 2 ], shape_b [- 1 ])
3943
@@ -64,16 +68,46 @@ def make_sdfg(node, parent_state, parent_sdfg):
6468
6569 state = sdfg .add_state_after (init_state , node .label + "_state" )
6670
67- state .add_mapped_tasklet (
68- '_BatchedBatchedMatMult_' , {
69- '__i%d' % i : '0:%s' % s
70- for i , s in enumerate ([bopt ['b' ], array_a .shape [- 2 ], array_b .shape [- 1 ], array_a .shape [- 1 ]])
71- }, {
72- '__a' : dace .Memlet .simple ("_a" , ('__i1, __i3' if len (array_a .shape ) == 2 else '__i0, __i1, __i3' )),
73- '__b' : dace .Memlet .simple ("_b" , ('__i3, __i2' if len (array_b .shape ) == 2 else '__i0, __i3, __i2' ))
74- },
75- '__c = __a * __b' , {'__c' : dace .Memlet .simple ("_c" , '__i0, __i1, __i2' , wcr_str = 'lambda x, y: x + y' )},
76- external_edges = True )
71+ # Calculate number of batch dimensions in output
72+ num_batch_dims = len (shape_c ) - 2
73+
74+ # Build map parameters: batch dimensions + M, N, K
75+ map_params = {}
76+ for i in range (num_batch_dims ):
77+ map_params ['__i%d' % i ] = '0:%s' % symstr (shape_c [i ])
78+
79+ # M, N, K dimensions
80+ map_params ['__im' ] = '0:%s' % symstr (shape_a [- 2 ])
81+ map_params ['__in' ] = '0:%s' % symstr (shape_b [- 1 ])
82+ map_params ['__ik' ] = '0:%s' % symstr (shape_a [- 1 ])
83+
84+ # Build memlet access patterns
85+ # For A: if 2D, use [M, K]; if 3D+, use [batch_indices..., M, K]
86+ if len (array_a .shape ) == 2 :
87+ memlet_a = '__im, __ik'
88+ else :
89+ # Use output batch indices
90+ a_batch_indices = ', ' .join (['__i%d' % i for i in range (len (array_a .shape ) - 2 )])
91+ memlet_a = f'{ a_batch_indices } , __im, __ik'
92+
93+ # For B: if 2D, use [K, N]; if 3D+, use [batch_indices..., K, N]
94+ if len (array_b .shape ) == 2 :
95+ memlet_b = '__ik, __in'
96+ else :
97+ b_batch_indices = ', ' .join (['__i%d' % i for i in range (len (array_b .shape ) - 2 )])
98+ memlet_b = f'{ b_batch_indices } , __ik, __in'
99+
100+ # For C: always has batch dimensions
101+ c_indices = ', ' .join (['__i%d' % i for i in range (num_batch_dims )]) + ', __im, __in'
102+
103+ state .add_mapped_tasklet ('_BatchedMatMult_' ,
104+ map_params , {
105+ '__a' : dace .Memlet .simple ("_a" , memlet_a ),
106+ '__b' : dace .Memlet .simple ("_b" , memlet_b )
107+ },
108+ '__c = __a * __b' ,
109+ {'__c' : dace .Memlet .simple ("_c" , c_indices , wcr_str = 'lambda x, y: x + y' )},
110+ external_edges = True )
77111
78112 return sdfg
79113
@@ -441,20 +475,31 @@ def validate(self, sdfg, state):
441475 raise ValueError ("Expected exactly one output from "
442476 "batched matrix-matrix product" )
443477 out_memlet = out_edges [0 ].data
444- # Function is symmetric, edge order does not matter
445- if len (size0 ) not in [2 , 3 ]:
446- raise ValueError ("Batched matrix-matrix product only supported on matrices" )
447- if len (size1 ) != 3 :
448- raise ValueError ("Batched matrix-matrix product only supported on matrices" )
478+
479+ # Both inputs must be at least 2D
480+ if len (size0 ) < 2 :
481+ raise ValueError (f"First input must be at least 2D, got shape with { len (size0 )} dimensions" )
482+ if len (size1 ) < 2 :
483+ raise ValueError (f"Second input must be at least 2D, got shape with { len (size1 )} dimensions" )
484+
485+ # At least one input must have batch dimensions (3D or higher) for batched operation
486+ if len (size0 ) <= 2 and len (size1 ) <= 2 :
487+ raise ValueError (
488+ "Batched matrix-matrix product requires at least one input to have batch dimensions (3D or higher)" )
489+
490+ # Validate K-dimension compatibility
449491 res = equal (size0 [- 1 ], size1 [- 2 ])
450492 if res is None :
451493 warnings .warn (
452494 f'First tensor\' s last mode { size0 [- 1 ]} and second tensor\' s second-last mode { size1 [- 2 ]} '
453495 f'may not match' , UserWarning )
454496 elif not res :
455497 raise ValueError ("Inputs to matrix-matrix product must agree in the k-dimension" )
456- if len (out_memlet .subset ) != 3 :
457- raise ValueError ("batched matrix-matrix product only supported on matrices" )
498+
499+ # Output must have batch dimensions
500+ if len (out_memlet .subset ) < 3 :
501+ raise ValueError (
502+ f"Batched matrix-matrix product output must be at least 3D, got { len (out_memlet .subset )} dimensions" )
458503
459504
460505# Numpy replacement
0 commit comments