|
10 | 10 | import networkx as nx |
11 | 11 |
|
12 | 12 |
|
| 13 | +def test_frontend_reference(): |
| 14 | + N = dace.symbol('N') |
| 15 | + M = dace.symbol('M') |
| 16 | + mystruct = dace.data.Structure(members={ |
| 17 | + "data": dace.data.Array(dace.float32, (N, M), strides=(1, N)), |
| 18 | + "arrA": dace.data.ArrayReference(dace.float32, (N, )), |
| 19 | + "arrB": dace.data.ArrayReference(dace.float32, (N, )), |
| 20 | + }, |
| 21 | + name="MyStruct") |
| 22 | + |
| 23 | + @dace.program |
| 24 | + def init_prog(mydat: mystruct, fill_value: int) -> None: |
| 25 | + mydat.arrA = mydat.data[:, 2] |
| 26 | + mydat.arrB = mydat.data[:, 0] |
| 27 | + |
| 28 | + # loop over all arrays and initialize them with `fill_value` |
| 29 | + for index in range(M): |
| 30 | + mydat.data[:, index] = fill_value |
| 31 | + |
| 32 | + # Initialize the two named ones by name |
| 33 | + mydat.arrA[:] = fill_value + 1 |
| 34 | + mydat.arrB[:] = fill_value + 2 |
| 35 | + |
| 36 | + dat = np.zeros((10, 5), dtype=np.float32) |
| 37 | + inp_struct = mystruct.dtype._typeclass.as_ctypes()(data=dat.__array_interface__['data'][0]) |
| 38 | + |
| 39 | + func = init_prog.compile() |
| 40 | + func(mydat=inp_struct, fill_value=3, N=10, M=5) |
| 41 | + |
| 42 | + assert np.allclose(dat[0, :], 5) and np.allclose(dat[1, :], 5) |
| 43 | + assert np.allclose(dat[2, :], 3) and np.allclose(dat[3, :], 3) |
| 44 | + assert np.allclose(dat[4, :], 4) and np.allclose(dat[5, :], 4) |
| 45 | + assert np.allclose(dat[6, :], 3) and np.allclose(dat[7, :], 3) |
| 46 | + assert np.allclose(dat[8, :], 3) and np.allclose(dat[9, :], 3) |
| 47 | + |
| 48 | + |
| 49 | +def test_type_annotation_reference(): |
| 50 | + N = dace.symbol('N') |
| 51 | + |
| 52 | + @dace.program |
| 53 | + def ref(A: dace.float64[N], B: dace.float64[N], T: dace.int32, out: dace.float64[N]): |
| 54 | + ref1: dace.data.ArrayReference(A.dtype, A.shape) = A |
| 55 | + ref2: dace.data.ArrayReference(A.dtype, A.shape) = B |
| 56 | + if T <= 0: |
| 57 | + out[:] = ref1[:] + 1 |
| 58 | + else: |
| 59 | + out[:] = ref2[:] + 1 |
| 60 | + |
| 61 | + a = np.random.rand(20) |
| 62 | + a_verif = a.copy() |
| 63 | + b = np.random.rand(20) |
| 64 | + b_verif = b.copy() |
| 65 | + out = np.random.rand(20) |
| 66 | + out_verif = out.copy() |
| 67 | + |
| 68 | + ref(a, b, 1, out, N=20) |
| 69 | + ref.f(a_verif, b_verif, 1, out_verif) |
| 70 | + assert np.allclose(out, out_verif) |
| 71 | + |
| 72 | + ref(a, b, -1, out, N=20) |
| 73 | + ref.f(a_verif, b_verif, -1, out_verif) |
| 74 | + assert np.allclose(out, out_verif) |
| 75 | + |
| 76 | + |
13 | 77 | def test_unset_reference(): |
14 | 78 | sdfg = dace.SDFG('tester') |
15 | 79 | sdfg.add_reference('ref', [20], dace.float64) |
@@ -683,6 +747,8 @@ def test_ref2view_reconnection(): |
683 | 747 |
|
684 | 748 |
|
685 | 749 | if __name__ == '__main__': |
| 750 | + test_frontend_reference() |
| 751 | + test_type_annotation_reference() |
686 | 752 | test_unset_reference() |
687 | 753 | test_reference_branch() |
688 | 754 | test_reference_sources_pass() |
|
0 commit comments