Skip to content

Commit c4c9447

Browse files
author
Pratyai Mazumder
committed
Merge remote-tracking branch 'origin/main' into f2dace-windmill
2 parents 44b6b63 + 93d8049 commit c4c9447

File tree

30 files changed

+2512
-106
lines changed

30 files changed

+2512
-106
lines changed

.github/workflows/general-ci.yml

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ jobs:
5959
else
6060
export DACE_optimizer_automatic_simplification=${{ matrix.simplify }}
6161
fi
62-
pytest -n auto --cov-report=xml --cov=dace --tb=short --timeout_method thread --timeout=300 -m "not gpu and not verilator and not tensorflow and not mkl and not sve and not papi and not mlir and not lapack and not fpga and not mpi and not rtl_hardware and not scalapack and not datainstrument and not long"
62+
pytest -n auto --cov-report=xml --cov=dace --tb=short --timeout_method thread --timeout=300 -m "not gpu and not verilator and not tensorflow and not mkl and not sve and not papi and not mlir and not lapack and not fpga and not mpi and not rtl_hardware and not scalapack and not datainstrument and not long and not sequential"
6363
./codecov
6464
6565
- name: Test OpenBLAS LAPACK
@@ -78,6 +78,22 @@ jobs:
7878
pytest -n 1 --cov-report=xml --cov=dace --tb=short --timeout_method thread --timeout=300 -m "lapack"
7979
./codecov
8080
81+
- name: Run sequential tests
82+
run: |
83+
export NOSTATUSBAR=1
84+
export DACE_testing_serialization=1
85+
export DACE_testing_deserialize_exception=1
86+
export DACE_cache=unique
87+
if [ "${{ matrix.simplify }}" = "autoopt" ]; then
88+
export DACE_optimizer_automatic_simplification=1
89+
export DACE_optimizer_autooptimize=1
90+
echo "Auto-optimization heuristics"
91+
else
92+
export DACE_optimizer_automatic_simplification=${{ matrix.simplify }}
93+
fi
94+
pytest -n 1 --cov-report=xml --cov=dace --tb=short --timeout_method thread --timeout=300 -m "sequential"
95+
./codecov
96+
8197
- name: Run other tests
8298
run: |
8399
export NOSTATUSBAR=1

dace/cli/dacelab.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python3
2-
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
2+
# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved.
33

44
import argparse
55
from dace.frontend.octave import parse

dace/codegen/compiled_sdfg.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
import subprocess
88
from typing import Any, Callable, Dict, List, Tuple, Optional, Type, Union
99
import warnings
10+
import tempfile
11+
import pickle
12+
import sys
1013

1114
import numpy as np
1215
import sympy as sp
@@ -414,6 +417,59 @@ def __call__(self, *args, **kwargs):
414417
# Return values are cached in `self._lastargs`.
415418
return self.fast_call(argtuple, initargtuple, do_gpu_check=True)
416419

420+
def safe_call(self, *args, **kwargs):
    """
    Forwards the Python call to the compiled ``SDFG`` in a separate process to
    avoid crashes in the main process.

    The SDFG, its shared-library path, and the call arguments are pickled to a
    temporary file; a child Python process loads them, reconstructs the
    ``CompiledSDFG``, runs it, and pickles the (potentially mutated) arguments
    back to the same file. Writable array arguments are then copied back
    in-place so the caller observes output values as if the SDFG ran locally.

    :raise RuntimeError: If the child process exits with a nonzero return code.

    .. note:: All arguments must be picklable. ``pickle`` is used here only
              for IPC with a child process we spawn ourselves — never feed
              this path untrusted data.
    """
    # Serialize the request (library path, SDFG, and arguments) for the child.
    with tempfile.NamedTemporaryFile(mode='wb', delete=False) as f:
        pickle.dump(
            {
                'library_path': self._lib._library_filename,
                'sdfg': self.sdfg,
                'args': args,
                'kwargs': kwargs
            }, f)
        temp_path = f.name

    # The child receives the request path via sys.argv[1] rather than having it
    # f-string-interpolated into the code (which breaks on paths containing
    # quotes or backslashes). It runs the SDFG and overwrites the file with the
    # resulting arguments.
    child_code = '''
import pickle
import sys
from dace.codegen import compiled_sdfg as csd

temp_path = sys.argv[1]
with open(temp_path, "rb") as f:
    data = pickle.load(f)

lib = csd.ReloadableDLL(data['library_path'], data['sdfg'].name)
obj = csd.CompiledSDFG(data['sdfg'], lib, data['sdfg'].arg_names)
obj(*data['args'], **data['kwargs'])

with open(temp_path, "wb") as f:
    pickle.dump({'args': data['args'], 'kwargs': data['kwargs']}, f)
'''

    try:
        # Call the SDFG in a separate process.
        result = subprocess.run([sys.executable, '-c', child_code, temp_path])

        # Check the return code *before* reading the file back: on failure the
        # file still holds the request payload (or a partially-written result),
        # and copying that back would silently clobber the caller's outputs.
        if result.returncode != 0:
            raise RuntimeError(f'SDFG execution failed with return code {result.returncode}.')

        # Receive the result and copy mutated (output) arguments back into the
        # caller's containers in-place.
        with open(temp_path, 'rb') as f:
            data = pickle.load(f)
        for i, arg in enumerate(args):
            if hasattr(arg, '__setitem__'):
                arg[slice(None)] = data['args'][i]
        for key, value in kwargs.items():
            if hasattr(value, '__setitem__'):
                value[slice(None)] = data['kwargs'][key]
    finally:
        # Always clean up the temporary file, even if the child process failed
        # or unpickling raised.
        os.remove(temp_path)
417473
def fast_call(
418474
self,
419475
callargs: Tuple[Any, ...],

dace/codegen/targets/cpp.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1187,8 +1187,11 @@ def _subscript_expr(self, slicenode: ast.AST, target: str) -> symbolic.SymbolicT
11871187
# - Soft-squeeze the slice (remove unit-modes) to match the treatment of the strides above.
11881188
if target not in self.constants:
11891189
desc = self.sdfg.arrays[dname]
1190-
if isinstance(desc, data.Array) and data._prod(desc.shape) != 1:
1191-
elts = [e for i, e in enumerate(visited_slice.elts) if desc.shape[i] != 1]
1190+
if sum(1 for s in desc.shape if s != 1) != len(visited_slice.elts):
1191+
if isinstance(desc, data.Array) and data._prod(desc.shape) != 1:
1192+
elts = [e for i, e in enumerate(visited_slice.elts) if desc.shape[i] != 1]
1193+
else:
1194+
elts = visited_slice.elts
11921195
else:
11931196
elts = visited_slice.elts
11941197
if len(strides) != len(elts):

dace/frontend/python/newast.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,9 @@ def add_indirection_subgraph(sdfg: SDFG,
536536
for i, idx in enumerate(nonsqz_dims):
537537
newsubset[idx] = '__i%d' % i
538538

539+
# Squeeze size-1 dimensions out of expression
540+
newsubset = [s for shp, s in zip(array.shape, newsubset) if shp != 1]
541+
539542
tasklet.code = CodeBlock(
540543
code.format(arr='__ind_' + local_name, index=', '.join([symbolic.symstr(s) for s in newsubset])))
541544

dace/frontend/python/replacements/linalg.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,6 @@ def _matmult(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op
2828

2929
if len(arr1.shape) > 1 and len(arr2.shape) > 1: # matrix * matrix
3030

31-
if len(arr1.shape) > 3 or len(arr2.shape) > 3:
32-
raise SyntaxError('Matrix multiplication of tensors of dimensions > 3 not supported')
33-
3431
res = symbolic.equal(arr1.shape[-1], arr2.shape[-2])
3532
if res is None:
3633
warnings.warn(
@@ -41,10 +38,12 @@ def _matmult(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op
4138

4239
from dace.libraries.blas.nodes.matmul import _get_batchmm_opts
4340

44-
# Determine batched multiplication
41+
# Determine batched multiplication (supports N-D tensors)
4542
bopt = _get_batchmm_opts(arr1.shape, arr1.strides, arr2.shape, arr2.strides, None, None)
4643
if bopt:
47-
output_shape = (bopt['b'], arr1.shape[-2], arr2.shape[-1])
44+
# Multi-dimensional batch: use batch_dims if available, otherwise use flattened batch size
45+
batch_dims = bopt.get('batch_dims', [bopt['b']])
46+
output_shape = tuple(batch_dims) + (arr1.shape[-2], arr2.shape[-1])
4847
else:
4948
output_shape = (arr1.shape[-2], arr2.shape[-1])
5049

dace/libraries/blas/nodes/batched_matmul.py

Lines changed: 63 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,12 @@ def make_sdfg(node, parent_state, parent_sdfg):
3232
UserWarning)
3333
elif not res:
3434
raise SyntaxError("Matrix sizes must match")
35+
36+
# Determine output shape based on batch options
3537
if bopt:
36-
shape_c = (bopt['b'], shape_a[-2], shape_b[-1])
38+
# Use batch dimensions from bopt (may be multi-dimensional)
39+
batch_dims = bopt.get('batch_dims', [bopt['b']])
40+
shape_c = tuple(batch_dims) + (shape_a[-2], shape_b[-1])
3741
else:
3842
shape_c = (shape_a[-2], shape_b[-1])
3943

@@ -64,16 +68,46 @@ def make_sdfg(node, parent_state, parent_sdfg):
6468

6569
state = sdfg.add_state_after(init_state, node.label + "_state")
6670

67-
state.add_mapped_tasklet(
68-
'_BatchedBatchedMatMult_', {
69-
'__i%d' % i: '0:%s' % s
70-
for i, s in enumerate([bopt['b'], array_a.shape[-2], array_b.shape[-1], array_a.shape[-1]])
71-
}, {
72-
'__a': dace.Memlet.simple("_a", ('__i1, __i3' if len(array_a.shape) == 2 else '__i0, __i1, __i3')),
73-
'__b': dace.Memlet.simple("_b", ('__i3, __i2' if len(array_b.shape) == 2 else '__i0, __i3, __i2'))
74-
},
75-
'__c = __a * __b', {'__c': dace.Memlet.simple("_c", '__i0, __i1, __i2', wcr_str='lambda x, y: x + y')},
76-
external_edges=True)
71+
# Calculate number of batch dimensions in output
72+
num_batch_dims = len(shape_c) - 2
73+
74+
# Build map parameters: batch dimensions + M, N, K
75+
map_params = {}
76+
for i in range(num_batch_dims):
77+
map_params['__i%d' % i] = '0:%s' % symstr(shape_c[i])
78+
79+
# M, N, K dimensions
80+
map_params['__im'] = '0:%s' % symstr(shape_a[-2])
81+
map_params['__in'] = '0:%s' % symstr(shape_b[-1])
82+
map_params['__ik'] = '0:%s' % symstr(shape_a[-1])
83+
84+
# Build memlet access patterns
85+
# For A: if 2D, use [M, K]; if 3D+, use [batch_indices..., M, K]
86+
if len(array_a.shape) == 2:
87+
memlet_a = '__im, __ik'
88+
else:
89+
# Use output batch indices
90+
a_batch_indices = ', '.join(['__i%d' % i for i in range(len(array_a.shape) - 2)])
91+
memlet_a = f'{a_batch_indices}, __im, __ik'
92+
93+
# For B: if 2D, use [K, N]; if 3D+, use [batch_indices..., K, N]
94+
if len(array_b.shape) == 2:
95+
memlet_b = '__ik, __in'
96+
else:
97+
b_batch_indices = ', '.join(['__i%d' % i for i in range(len(array_b.shape) - 2)])
98+
memlet_b = f'{b_batch_indices}, __ik, __in'
99+
100+
# For C: always has batch dimensions
101+
c_indices = ', '.join(['__i%d' % i for i in range(num_batch_dims)]) + ', __im, __in'
102+
103+
state.add_mapped_tasklet('_BatchedMatMult_',
104+
map_params, {
105+
'__a': dace.Memlet.simple("_a", memlet_a),
106+
'__b': dace.Memlet.simple("_b", memlet_b)
107+
},
108+
'__c = __a * __b',
109+
{'__c': dace.Memlet.simple("_c", c_indices, wcr_str='lambda x, y: x + y')},
110+
external_edges=True)
77111

78112
return sdfg
79113

@@ -441,20 +475,31 @@ def validate(self, sdfg, state):
441475
raise ValueError("Expected exactly one output from "
442476
"batched matrix-matrix product")
443477
out_memlet = out_edges[0].data
444-
# Function is symmetric, edge order does not matter
445-
if len(size0) not in [2, 3]:
446-
raise ValueError("Batched matrix-matrix product only supported on matrices")
447-
if len(size1) != 3:
448-
raise ValueError("Batched matrix-matrix product only supported on matrices")
478+
479+
# Both inputs must be at least 2D
480+
if len(size0) < 2:
481+
raise ValueError(f"First input must be at least 2D, got shape with {len(size0)} dimensions")
482+
if len(size1) < 2:
483+
raise ValueError(f"Second input must be at least 2D, got shape with {len(size1)} dimensions")
484+
485+
# At least one input must have batch dimensions (3D or higher) for batched operation
486+
if len(size0) <= 2 and len(size1) <= 2:
487+
raise ValueError(
488+
"Batched matrix-matrix product requires at least one input to have batch dimensions (3D or higher)")
489+
490+
# Validate K-dimension compatibility
449491
res = equal(size0[-1], size1[-2])
450492
if res is None:
451493
warnings.warn(
452494
f'First tensor\'s last mode {size0[-1]} and second tensor\'s second-last mode {size1[-2]} '
453495
f'may not match', UserWarning)
454496
elif not res:
455497
raise ValueError("Inputs to matrix-matrix product must agree in the k-dimension")
456-
if len(out_memlet.subset) != 3:
457-
raise ValueError("batched matrix-matrix product only supported on matrices")
498+
499+
# Output must have batch dimensions
500+
if len(out_memlet.subset) < 3:
501+
raise ValueError(
502+
f"Batched matrix-matrix product output must be at least 3D, got {len(out_memlet.subset)} dimensions")
458503

459504

460505
# Numpy replacement

0 commit comments

Comments
 (0)