Skip to content

Commit c7b5a10

Browse files
[v1 Maintenance] Back port MapExpansion fixes from v2 (#2257)
This PR brings the fixes to `MapExpansion` [transform done in v2](https://github.com/spcl/dace/blob/156567b1eea3b54cd3dda0b6d3f259995127be68/dace/transformation/dataflow/map_expansion.py#L135) back into v1.
1 parent aa5f8b5 commit c7b5a10

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

dace/transformation/dataflow/map_expansion.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,10 +136,14 @@ def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG):
136136
graph.add_edge(entries[-1], edge.src_conn, edge.dst, edge.dst_conn, memlet=copy.deepcopy(edge.data))
137137
graph.remove_edge(edge)
138138

139-
if graph.in_degree(map_entry) == 0:
139+
if graph.in_degree(map_entry) == 0 or all(
140+
e.dst_conn is None or not e.dst_conn.startswith("IN_")
141+
for e in graph.in_edges(map_entry)):
140142
graph.add_memlet_path(map_entry, *entries, memlet=dace.Memlet())
141143
else:
142144
for edge in graph.in_edges(map_entry):
145+
if edge.dst_conn is None:
146+
continue
143147
if not edge.dst_conn.startswith("IN_"):
144148
continue
145149

0 commit comments

Comments
 (0)