Skip to content

Commit ead1fa6

Browse files
Fix StripMining and ApplyOnce (#2223)
1 parent aebd112 commit ead1fa6

File tree

6 files changed

+297
-68
lines changed

6 files changed

+297
-68
lines changed

dace/transformation/dataflow/strip_mining.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,10 @@ def _stripmine(self, sdfg: SDFG, graph: SDFGState, map_entry: nodes.MapEntry):
371371
new_map.schedule = map_entry.map.schedule
372372
map_entry.map.schedule = dtypes.ScheduleType.Sequential
373373

374+
# Record the memlet path of every out edge before any edges are redirected or added;
375+
# the paths are needed later to resolve the source data of memlets that use other_subset
376+
edge_to_src_memlet_paths = {out_edge: graph.memlet_path(out_edge) for out_edge in graph.out_edges(map_entry)}
377+
374378
# Redirect edges
375379
new_map_entry.in_connectors = dcpy(map_entry.in_connectors)
376380
sdutil.change_edge_dest(graph, map_entry, new_map_entry)
@@ -381,16 +385,29 @@ def _stripmine(self, sdfg: SDFG, graph: SDFGState, map_entry: nodes.MapEntry):
381385
new_in_edges = dict()
382386
entry_in_conn = {}
383387
entry_out_conn = {}
384-
for _src, src_conn, _dst, _, memlet in graph.out_edges(map_entry):
388+
for out_edge in graph.out_edges(map_entry):
389+
_src, src_conn, _dst, _, memlet = out_edge
390+
# If the memlet has the pattern <arr_subset> -> <scalar_data><scalar_subset> (or any other kind of other_subset),
391+
# we need to expand <arr_subset>, and therefore need to look up the source data through the memlet path
392+
src_data_name = None
393+
if memlet.other_subset is None:
394+
src_data_name = memlet.data
395+
subset = memlet.subset
396+
else:
397+
src_edge = edge_to_src_memlet_paths[out_edge][0]
398+
src_data_name = src_edge.src.data if isinstance(src_edge.src,
399+
dace.nodes.AccessNode) else src_edge.data.data
400+
subset = memlet.src_subset
401+
385402
if (src_conn is not None and src_conn[:4] == 'OUT_'
386-
and not isinstance(sdfg.arrays[memlet.data], dace.data.Scalar)):
403+
and not isinstance(sdfg.arrays[src_data_name], dace.data.Scalar)):
387404
new_subset = calc_set_image(
388405
map_entry.map.params,
389406
map_entry.map.range,
390-
memlet.subset,
407+
subset,
391408
)
392409
conn = src_conn[4:]
393-
key = (memlet.data, 'IN_' + conn, 'OUT_' + conn)
410+
key = (src_data_name, 'IN_' + conn, 'OUT_' + conn)
394411
if key in new_in_edges.keys():
395412
old_subset = new_in_edges[key].subset
396413
new_in_edges[key].subset = calc_set_union(old_subset, new_subset)
@@ -420,6 +437,7 @@ def _stripmine(self, sdfg: SDFG, graph: SDFGState, map_entry: nodes.MapEntry):
420437
new_in_edges[(memlet.data, in_conn, out_conn)] = dcpy(memlet)
421438
new_map_entry.out_connectors = entry_out_conn
422439
map_entry.in_connectors = entry_in_conn
440+
423441
for (_, in_conn, out_conn), memlet in new_in_edges.items():
424442
graph.add_edge(new_map_entry, out_conn, map_entry, in_conn, memlet)
425443

dace/transformation/dataflow/tiling.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,4 +120,5 @@ def apply(self, graph: SDFGState, sdfg: SDFG):
120120
mapcollapse.setup_match(sdfg, cfg_id, self.state_id, mapcollapse_subgraph, 0)
121121
mapcollapse.apply(graph, sdfg)
122122
last_map_entry = graph.in_edges(map_entry)[0].src
123+
123124
return last_map_entry

dace/transformation/passes/pattern_matching.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -244,13 +244,17 @@ def _apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any], apply_once:
244244
applied = True
245245
applied_anything = True
246246
break
247+
248+
# If apply_once is set, stop here once the transformation has been applied to all patterns instead of looping again
249+
if apply_once:
250+
break
251+
247252
if apply_once:
248253
break
249254
else:
250255
applied = True
251256
while applied:
252257
applied = False
253-
# Find and apply one of the chosen transformations
254258
for match in match_patterns(sdfg,
255259
permissive=self.permissive,
256260
patterns=xforms,
@@ -267,18 +271,14 @@ def _apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any], apply_once:
267271
sdfg.validate()
268272
except InvalidSDFGError as err:
269273
if applied and match is not None:
270-
raise InvalidSDFGError("Validation failed after applying {}.".format(match.print_match(self)), self,
274+
raise InvalidSDFGError(f"Validation failed after applying {match.print_match(self)}.", self,
271275
match.state_id) from err
272276
else:
273277
raise err
274278

275-
if (len(applied_transformations) > 0
276-
and (self.progress or self.print_report or
277-
((self.progress is None or self.print_report is None) and Config.get_bool('debugprint')))):
278-
print('Applied {}.'.format(', '.join(['%d %s' % (len(v), k) for k, v in applied_transformations.items()])))
279-
280-
if len(applied_transformations) == 0: # Signal that no transformation was applied
279+
if len(applied_transformations) == 0:
281280
return None
281+
282282
return applied_transformations
283283

284284
def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[str, List[Any]]:

tests/transformations/add_threadblock_map_test.py

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import dace
44
import pytest
55
import numpy
6+
import sys
67
from dace.transformation.dataflow.add_threadblock_map import AddThreadBlockMap
78

89
N = dace.symbol("N")
@@ -40,6 +41,108 @@ def elementwise_with_floor_div(A: dace.float64[512] @ dace.dtypes.StorageType.GP
4041
A[i] = 2.0 * B[i]
4142

4243

44+
def _get_sdfg_with_memlet_tree():
45+
sdfg = dace.SDFG("test")
46+
state = sdfg.add_state(is_start_block=True)
47+
48+
for aname in "ab":
49+
sdfg.add_array(
50+
aname,
51+
shape=(10, 2),
52+
dtype=dace.float64,
53+
storage=dace.dtypes.StorageType.GPU_Global,
54+
transient=False,
55+
)
56+
sdfg.add_scalar(
57+
"s",
58+
dtype=dace.float64,
59+
transient=True,
60+
)
61+
62+
a, b, s = (state.add_access(name) for name in "abs")
63+
me, mx = state.add_map("comp", ndrange={"__i": "0:10"}, schedule=dace.dtypes.ScheduleType.GPU_Device)
64+
tlet = state.add_tasklet(
65+
"tlet",
66+
inputs={"__in"},
67+
outputs={"__out"},
68+
code="__out = __in + 1.0",
69+
)
70+
71+
state.add_edge(
72+
a,
73+
None,
74+
me,
75+
"IN_a1",
76+
dace.Memlet("a[0:10, 0]"),
77+
)
78+
state.add_edge(
79+
me,
80+
"OUT_a1",
81+
tlet,
82+
"__in",
83+
dace.Memlet("a[__i, 0]"),
84+
)
85+
me.add_scope_connectors("a1")
86+
87+
state.add_edge(
88+
tlet,
89+
"__out",
90+
mx,
91+
"IN_b1",
92+
dace.Memlet("b[__i, 0]"),
93+
)
94+
state.add_edge(
95+
mx,
96+
"OUT_b1",
97+
b,
98+
None,
99+
dace.Memlet("b[0:10, 0]"),
100+
)
101+
mx.add_scope_connectors("b1")
102+
103+
state.add_edge(
104+
me,
105+
# It is also important that we read from the same connector as the tasklet.
106+
"OUT_a1",
107+
s,
108+
None,
109+
# As far as I understand, the error is triggered here: the data of this
110+
# Memlet refers to `s`, whereas the outer Memlet refers to `a`.
111+
dace.Memlet("s[0] -> [__i, 0]"),
112+
)
113+
114+
state.add_edge(
115+
s,
116+
None,
117+
mx,
118+
"IN_b2",
119+
dace.Memlet("b[__i, 1] -> [0]"),
120+
)
121+
state.add_edge(
122+
mx,
123+
"OUT_b2",
124+
b,
125+
None,
126+
dace.Memlet("b[0:10, 1]"),
127+
)
128+
mx.add_scope_connectors("b2")
129+
130+
sdfg.validate()
131+
return sdfg
132+
133+
134+
def test_memlet_tree():
135+
sdfg = _get_sdfg_with_memlet_tree()
136+
137+
sdfg.apply_transformations_once_everywhere(
138+
AddThreadBlockMap,
139+
validate=True,
140+
validate_all=True,
141+
)
142+
143+
sdfg.validate()
144+
145+
43146
def _run_and_compare(prog, A_host, B_host, constants=None):
44147
"""Run SDFG with and without AddThreadBlockMap and compare results."""
45148
import cupy
@@ -122,5 +225,7 @@ def test_elementwise_with_floor_div():
122225
if __name__ == "__main__":
123226
test_elementwise_constexpr_size()
124227
test_elementwise_small_constexpr_size()
125-
test_elementwise_symbolic()
228+
for symbol_param in symbol_params:
229+
test_elementwise_symbolic(symbol_param)
126230
test_elementwise_with_floor_div()
231+
test_memlet_tree()
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved.
2+
from typing import List
3+
import dace
4+
import numpy as np
5+
from dace.transformation.dataflow import MapTiling
6+
import pytest
7+
8+
N = dace.symbol('N')
9+
10+
11+
def test_map_tiling_with_strides():
12+
s = 33
13+
14+
@dace.program
15+
def vector_copy_strides(A: dace.uint32[N], B: dace.uint32[N]):
16+
for i in dace.map[0:N:s] @ dace.dtypes.ScheduleType.CPU_Multicore:
17+
A[i] = B[i]
18+
19+
sdfg = vector_copy_strides.to_sdfg()
20+
sdfg.simplify()
21+
assert len(list(sdfg.all_states())) == 1
22+
23+
state = next(iter(sdfg.all_states()))
24+
state_nodes = state.nodes()
25+
map_entries: List[dace.nodes.MapEntry] = [n for n in state_nodes if isinstance(n, dace.nodes.MapEntry)]
26+
assert len(map_entries) == 1
27+
map_entry = map_entries[0]
28+
29+
tile_sizes = [32]
30+
MapTiling.apply_to(sdfg=sdfg,
31+
options={
32+
"prefix": "b",
33+
"tile_sizes": tile_sizes,
34+
"divides_evenly": False,
35+
"tile_trivial": True,
36+
"skew": False
37+
},
38+
map_entry=map_entry)
39+
inner_map_entry = map_entry
40+
outer_map_entry = state.entry_node(inner_map_entry)
41+
42+
b_i = dace.symbol("b_i")
43+
inner_rangelist = [(b_i, dace.symbolic.SymExpr("Min(N - 1, b_i + 32*33 - 1)"), 33)]
44+
outer_rangelist = [(0, N - 1, 32 * 33)]
45+
inner_range = dace.subsets.Range(inner_rangelist)
46+
outer_range = dace.subsets.Range(outer_rangelist)
47+
48+
sdfg.validate()
49+
50+
assert inner_map_entry.map.range == inner_range
51+
assert outer_map_entry.map.range == outer_range
52+
53+
54+
def _get_sdfg_with_memlet_tree():
55+
sdfg = dace.SDFG("test")
56+
state = sdfg.add_state(is_start_block=True)
57+
58+
for aname in "ab":
59+
sdfg.add_array(
60+
aname,
61+
shape=(10, 2),
62+
dtype=dace.float64,
63+
storage=dace.dtypes.StorageType.GPU_Global,
64+
transient=False,
65+
)
66+
sdfg.add_scalar(
67+
"s",
68+
dtype=dace.float64,
69+
transient=True,
70+
)
71+
72+
a, b, s = (state.add_access(name) for name in "abs")
73+
me, mx = state.add_map("comp", ndrange={"__i": "0:10"}, schedule=dace.dtypes.ScheduleType.GPU_Device)
74+
tlet = state.add_tasklet(
75+
"tlet",
76+
inputs={"__in"},
77+
outputs={"__out"},
78+
code="__out = __in + 1.0",
79+
)
80+
81+
state.add_edge(
82+
a,
83+
None,
84+
me,
85+
"IN_a1",
86+
dace.Memlet("a[0:10, 0]"),
87+
)
88+
state.add_edge(
89+
me,
90+
"OUT_a1",
91+
tlet,
92+
"__in",
93+
dace.Memlet("a[__i, 0]"),
94+
)
95+
me.add_scope_connectors("a1")
96+
97+
state.add_edge(
98+
tlet,
99+
"__out",
100+
mx,
101+
"IN_b1",
102+
dace.Memlet("b[__i, 0]"),
103+
)
104+
state.add_edge(
105+
mx,
106+
"OUT_b1",
107+
b,
108+
None,
109+
dace.Memlet("b[0:10, 0]"),
110+
)
111+
mx.add_scope_connectors("b1")
112+
113+
state.add_edge(
114+
me,
115+
# It is also important that we read from the same connector as the tasklet.
116+
"OUT_a1",
117+
s,
118+
None,
119+
# As far as I understand, the error is triggered here: the data of this
120+
# Memlet refers to `s`, whereas the outer Memlet refers to `a`.
121+
dace.Memlet("s[0] -> [__i, 0]"),
122+
)
123+
124+
state.add_edge(
125+
s,
126+
None,
127+
mx,
128+
"IN_b2",
129+
dace.Memlet("b[__i, 1] -> [0]"),
130+
)
131+
state.add_edge(
132+
mx,
133+
"OUT_b2",
134+
b,
135+
None,
136+
dace.Memlet("b[0:10, 1]"),
137+
)
138+
mx.add_scope_connectors("b2")
139+
140+
sdfg.validate()
141+
return sdfg
142+
143+
144+
def test_memlet_tree():
145+
sdfg = _get_sdfg_with_memlet_tree()
146+
sdfg.apply_transformations_once_everywhere(
147+
MapTiling,
148+
validate=True,
149+
validate_all=True,
150+
options={
151+
"tile_sizes": (2, ),
152+
},
153+
print_report=True,
154+
)
155+
sdfg.validate()
156+
157+
158+
if __name__ == '__main__':
159+
test_map_tiling_with_strides()
160+
test_memlet_tree()

0 commit comments

Comments
 (0)