Skip to content

Commit c2e952f

Browse files
phschaadtbennun
andauthored
Backport of #2165 (#2166)
Co-authored-by: Tal Ben-Nun <tbennun@users.noreply.github.com>
1 parent 1810f49 commit c2e952f

File tree

2 files changed

+82
-4
lines changed

2 files changed

+82
-4
lines changed

dace/frontend/python/newast.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from dace.frontend.python import nested_call, replacements, preprocessing
2727
from dace.frontend.python.memlet_parser import DaceSyntaxError, parse_memlet, ParseMemlet, inner_eval_ast, MemletExpr
2828
from dace.sdfg import nodes
29-
from dace.sdfg.propagation import propagate_memlet, propagate_subset, propagate_states
29+
from dace.sdfg.propagation import propagate_memlet, propagate_subset, propagate_states, align_memlet
3030
from dace.memlet import Memlet
3131
from dace.properties import LambdaProperty, CodeBlock
3232
from dace.sdfg import SDFG, SDFGState
@@ -2774,7 +2774,12 @@ def _add_assignment(self,
27742774
memlet.other_subset = op_subset
27752775
if op:
27762776
memlet.wcr = LambdaProperty.from_string('lambda x, y: x {} y'.format(op))
2777-
state.add_nedge(op1, op2, memlet)
2777+
if isinstance(self.sdfg.arrays[target_name], data.Reference):
2778+
e = state.add_edge(op1, None, op2, 'set', memlet)
2779+
# Align memlet to referenced array
2780+
e.data = align_memlet(state, e, dst=False)
2781+
else:
2782+
state.add_nedge(op1, op2, memlet)
27782783
else:
27792784
memlet = Memlet("{a}[{s}]".format(a=target_name,
27802785
s=','.join(['__i%d' % i for i in range(len(target_subset))])))
@@ -3272,9 +3277,10 @@ def visit_AnnAssign(self, node: ast.AnnAssign):
32723277
storage = dtypes.StorageType.Default
32733278
type_name = rname(node.annotation)
32743279
warnings.warn('typeclass {} is not supported'.format(type_name))
3275-
if node.value is None and dtype is not None: # Annotating type without assignment
3280+
if dtype is not None:
32763281
self.annotated_types[rname(node.target)] = dtype
3277-
return
3282+
if node.value is None: # Annotating type without assignment
3283+
return
32783284
results = self._visit_assign(node, node.target, None, dtype=dtype)
32793285
if storage != dtypes.StorageType.Default:
32803286
self.sdfg.arrays[results[0][0]].storage = storage
@@ -3403,6 +3409,12 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False):
34033409
true_name, new_data = self.sdfg.add_temp_transient([1], result_data.dtype)
34043410
self.variables[name] = true_name
34053411
defined_vars[name] = true_name
3412+
elif name in self.annotated_types and isinstance(self.annotated_types[name], data.Reference):
3413+
desc = copy.deepcopy(self.annotated_types[name])
3414+
desc.transient = True
3415+
true_name = self.sdfg.add_datadesc(name, desc, find_new_name=True)
3416+
self.variables[name] = true_name
3417+
defined_vars[name] = true_name
34063418
elif (not name.startswith('__return')
34073419
and (isinstance(result_data, data.View) or
34083420
(not result_data.transient and isinstance(result_data, data.Array)))):

tests/sdfg/reference_test.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,70 @@
1010
import networkx as nx
1111

1212

13+
def test_frontend_reference():
14+
N = dace.symbol('N')
15+
M = dace.symbol('M')
16+
mystruct = dace.data.Structure(members={
17+
"data": dace.data.Array(dace.float32, (N, M), strides=(1, N)),
18+
"arrA": dace.data.ArrayReference(dace.float32, (N, )),
19+
"arrB": dace.data.ArrayReference(dace.float32, (N, )),
20+
},
21+
name="MyStruct")
22+
23+
@dace.program
24+
def init_prog(mydat: mystruct, fill_value: int) -> None:
25+
mydat.arrA = mydat.data[:, 2]
26+
mydat.arrB = mydat.data[:, 0]
27+
28+
# loop over all arrays and initialize them with `fill_value`
29+
for index in range(M):
30+
mydat.data[:, index] = fill_value
31+
32+
# Initialize the two named ones by name
33+
mydat.arrA[:] = fill_value + 1
34+
mydat.arrB[:] = fill_value + 2
35+
36+
dat = np.zeros((10, 5), dtype=np.float32)
37+
inp_struct = mystruct.dtype._typeclass.as_ctypes()(data=dat.__array_interface__['data'][0])
38+
39+
func = init_prog.compile()
40+
func(mydat=inp_struct, fill_value=3, N=10, M=5)
41+
42+
assert np.allclose(dat[0, :], 5) and np.allclose(dat[1, :], 5)
43+
assert np.allclose(dat[2, :], 3) and np.allclose(dat[3, :], 3)
44+
assert np.allclose(dat[4, :], 4) and np.allclose(dat[5, :], 4)
45+
assert np.allclose(dat[6, :], 3) and np.allclose(dat[7, :], 3)
46+
assert np.allclose(dat[8, :], 3) and np.allclose(dat[9, :], 3)
47+
48+
49+
def test_type_annotation_reference():
50+
N = dace.symbol('N')
51+
52+
@dace.program
53+
def ref(A: dace.float64[N], B: dace.float64[N], T: dace.int32, out: dace.float64[N]):
54+
ref1: dace.data.ArrayReference(A.dtype, A.shape) = A
55+
ref2: dace.data.ArrayReference(A.dtype, A.shape) = B
56+
if T <= 0:
57+
out[:] = ref1[:] + 1
58+
else:
59+
out[:] = ref2[:] + 1
60+
61+
a = np.random.rand(20)
62+
a_verif = a.copy()
63+
b = np.random.rand(20)
64+
b_verif = b.copy()
65+
out = np.random.rand(20)
66+
out_verif = out.copy()
67+
68+
ref(a, b, 1, out, N=20)
69+
ref.f(a_verif, b_verif, 1, out_verif)
70+
assert np.allclose(out, out_verif)
71+
72+
ref(a, b, -1, out, N=20)
73+
ref.f(a_verif, b_verif, -1, out_verif)
74+
assert np.allclose(out, out_verif)
75+
76+
1377
def test_unset_reference():
1478
sdfg = dace.SDFG('tester')
1579
sdfg.add_reference('ref', [20], dace.float64)
@@ -683,6 +747,8 @@ def test_ref2view_reconnection():
683747

684748

685749
if __name__ == '__main__':
750+
test_frontend_reference()
751+
test_type_annotation_reference()
686752
test_unset_reference()
687753
test_reference_branch()
688754
test_reference_sources_pass()

0 commit comments

Comments
 (0)