|
3 | 3 | from collections import OrderedDict |
4 | 4 | import copy |
5 | 5 | import itertools |
6 | | -import networkx as nx |
7 | 6 | import re |
8 | 7 | import sys |
9 | 8 | import time |
|
15 | 14 |
|
16 | 15 | import dace |
17 | 16 | from dace import data, dtypes, subsets, symbolic, sdfg as sd |
18 | | -from dace import sourcemap |
19 | 17 | from dace.config import Config |
20 | 18 | from dace.frontend.common import op_repository as oprepo |
21 | 19 | from dace.frontend.python import astutils |
@@ -475,7 +473,7 @@ def add_indirection_subgraph(sdfg: SDFG, |
475 | 473 | arr = sdfg.arrays[arrname] |
476 | 474 | subset = subsets.Range.from_array(arr) |
477 | 475 | else: |
478 | | - subset = subsets.Indices(access) |
| 476 | + subset = subsets.Range.from_indices(access) |
479 | 477 | # Memlet to load the indirection index |
480 | 478 | indexMemlet = Memlet.simple(arrname, subset) |
481 | 479 | input_index_memlets.append(indexMemlet) |
@@ -1771,8 +1769,8 @@ def _inject_consume_memlets(self, dec, entry, inputs, internal_node, sdfg: SDFG, |
1771 | 1769 | # Inject element to internal SDFG arrays |
1772 | 1770 | ntrans = f'consume_{stream_name}' |
1773 | 1771 | ntrans, _ = sdfg.add_array(ntrans, [1], self.sdfg.arrays[stream_name].dtype, find_new_name=True) |
1774 | | - internal_memlet = dace.Memlet.simple(ntrans, subsets.Indices([0])) |
1775 | | - external_memlet = dace.Memlet.simple(stream_name, subsets.Indices([0]), num_accesses=-1) |
| 1772 | + internal_memlet = dace.Memlet.simple(ntrans, subsets.Range.from_indices([0])) |
| 1773 | + external_memlet = dace.Memlet.simple(stream_name, subsets.Range.from_indices([0]), num_accesses=-1) |
1776 | 1774 |
|
1777 | 1775 | # Inject to internal tasklet |
1778 | 1776 | if not dec.endswith('scope'): |
@@ -5290,82 +5288,97 @@ def _add_read_slice(self, array: str, node: ast.Subscript, expr: MemletExpr): |
5290 | 5288 | # Make copy slicing state |
5291 | 5289 | rnode = self.current_state.add_read(array, debuginfo=self.current_lineinfo) |
5292 | 5290 | return self._array_indirection_subgraph(rnode, expr) |
| 5291 | + |
| 5292 | + is_index = False |
| 5293 | + if isinstance(expr.subset, subsets.Indices): |
| 5294 | + is_index = True |
| 5295 | + other_subset = subsets.Range([(i, i, 1) for i in expr.subset]) |
5293 | 5296 | else: |
5294 | | - is_index = False |
5295 | | - if isinstance(expr.subset, subsets.Indices): |
5296 | | - is_index = True |
5297 | | - other_subset = subsets.Range([(i, i, 1) for i in expr.subset]) |
5298 | | - else: |
5299 | | - other_subset = copy.deepcopy(expr.subset) |
5300 | | - strides = list(arrobj.strides) |
5301 | | - |
5302 | | - # Make new axes and squeeze for scalar subsets (as per numpy behavior) |
5303 | | - # For example: A[0, np.newaxis, 5:7] results in a 1x2 ndarray |
5304 | | - new_axes = [] |
5305 | | - if expr.new_axes: |
5306 | | - new_axes = other_subset.unsqueeze(expr.new_axes) |
5307 | | - for i in new_axes: |
5308 | | - strides.insert(i, 1) |
5309 | | - length = len(other_subset) |
5310 | | - nsqz = other_subset.squeeze(ignore_indices=new_axes) |
5311 | | - sqz = [i for i in range(length) if i not in nsqz] |
5312 | | - for i in reversed(sqz): |
5313 | | - strides.pop(i) |
5314 | | - if not strides: |
5315 | | - strides = None |
5316 | | - |
5317 | | - if is_index: |
5318 | | - tmp = self.get_target_name(default=f'{array}_index') |
5319 | | - tmp, tmparr = self.sdfg.add_scalar(tmp, |
5320 | | - arrobj.dtype, |
5321 | | - arrobj.storage, |
5322 | | - transient=True, |
5323 | | - find_new_name=True) |
5324 | | - else: |
5325 | | - for i in range(len(other_subset.ranges)): |
5326 | | - rb, re, rs = other_subset.ranges[i] |
5327 | | - if (rs < 0) == True: |
5328 | | - raise DaceSyntaxError( |
5329 | | - self, node, 'Negative strides are not supported in subscripts. ' |
5330 | | - 'Please use a Map scope to express this operation.') |
5331 | | - re = re - rb |
5332 | | - rb = 0 |
5333 | | - if rs != 1: |
5334 | | - # NOTE: We use the identity floor(A/B) = ceiling((A + 1) / B) - 1 |
5335 | | - # because Range.size() uses the ceiling method and that way we avoid |
5336 | | - # false negatives when testing the equality of data shapes. |
5337 | | - # re = re // rs |
5338 | | - re = sympy.ceiling((re + 1) / rs) - 1 |
5339 | | - strides[i] *= rs |
5340 | | - rs = 1 |
5341 | | - other_subset.ranges[i] = (rb, re, rs) |
5342 | | - |
5343 | | - tmp, tmparr = self.sdfg.add_view(array, |
5344 | | - other_subset.size(), |
5345 | | - arrobj.dtype, |
5346 | | - storage=arrobj.storage, |
5347 | | - strides=strides, |
5348 | | - find_new_name=True) |
5349 | | - self.views[tmp] = (array, |
5350 | | - Memlet(data=array, |
5351 | | - subset=str(expr.subset), |
5352 | | - other_subset=str(other_subset), |
5353 | | - volume=expr.accesses, |
5354 | | - wcr=expr.wcr)) |
5355 | | - self.variables[tmp] = tmp |
5356 | | - if not isinstance(tmparr, data.View): |
5357 | | - rnode = self.current_state.add_read(array, debuginfo=self.current_lineinfo) |
5358 | | - wnode = self.current_state.add_write(tmp, debuginfo=self.current_lineinfo) |
5359 | | - # NOTE: We convert the subsets to string because keeping the original symbolic information causes |
5360 | | - # equality check failures, e.g., in LoopToMap. |
5361 | | - self.current_state.add_nedge( |
5362 | | - rnode, wnode, |
5363 | | - Memlet(data=array, |
5364 | | - subset=str(expr.subset), |
5365 | | - other_subset=str(other_subset), |
5366 | | - volume=expr.accesses, |
5367 | | - wcr=expr.wcr)) |
5368 | | - return tmp |
| 5297 | + |
def range_is_index(rng: 'subsets.Range') -> bool:
    """
    Check if the given subset range describes a single index.

    Conditions for an index are as follows:
    - tile_size of each range has to be 1
    - the range increment (step) has to be 1
    - start/stop element of the range have to be equal

    :param rng: The subset range to inspect (annotation is a forward
                reference so the helper can be defined without ``subsets``
                in scope). Callers pass the argument positionally.
    :return: True if every dimension is a unit-tile, unit-step, single-element
             range (an empty range list vacuously counts as an index).
    """
    # NOTE: parameter renamed from `range` to `rng` to avoid shadowing the
    # builtin `range`; the sole call site passes it positionally, so this is
    # interface-compatible.
    for (start, stop, step), tile in zip(rng.ranges, rng.tile_sizes):
        if tile != 1 or step != 1 or start != stop:
            return False
    return True
| 5311 | + |
| 5312 | + # We also check the type of the slice attribute of the node |
| 5313 | + # in order to distinguish between A[0] and A[0:1], which are semantically different in numpy |
| 5314 | + # (the former is an index, the latter is a slice). |
| 5315 | + is_index = range_is_index(expr.subset) and not isinstance(node.slice, ast.Slice) |
| 5316 | + other_subset = copy.deepcopy(expr.subset) |
| 5317 | + strides = list(arrobj.strides) |
| 5318 | + |
| 5319 | + # Make new axes and squeeze for scalar subsets (as per numpy behavior) |
| 5320 | + # For example: A[0, np.newaxis, 5:7] results in a 1x2 ndarray |
| 5321 | + new_axes = [] |
| 5322 | + if expr.new_axes: |
| 5323 | + new_axes = other_subset.unsqueeze(expr.new_axes) |
| 5324 | + for i in new_axes: |
| 5325 | + strides.insert(i, 1) |
| 5326 | + length = len(other_subset) |
| 5327 | + nsqz = other_subset.squeeze(ignore_indices=new_axes) |
| 5328 | + sqz = [i for i in range(length) if i not in nsqz] |
| 5329 | + for i in reversed(sqz): |
| 5330 | + strides.pop(i) |
| 5331 | + if not strides: |
| 5332 | + strides = None |
| 5333 | + |
| 5334 | + if is_index: |
| 5335 | + tmp = self.get_target_name(default=f'{array}_index') |
| 5336 | + tmp, tmparr = self.sdfg.add_scalar(tmp, arrobj.dtype, arrobj.storage, transient=True, find_new_name=True) |
| 5337 | + else: |
| 5338 | + for i in range(len(other_subset.ranges)): |
| 5339 | + rb, re, rs = other_subset.ranges[i] |
| 5340 | + if (rs < 0) == True: |
| 5341 | + raise DaceSyntaxError( |
| 5342 | + self, node, 'Negative strides are not supported in subscripts. ' |
| 5343 | + 'Please use a Map scope to express this operation.') |
| 5344 | + re = re - rb |
| 5345 | + rb = 0 |
| 5346 | + if rs != 1: |
| 5347 | + # NOTE: We use the identity floor(A/B) = ceiling((A + 1) / B) - 1 |
| 5348 | + # because Range.size() uses the ceiling method and that way we avoid |
| 5349 | + # false negatives when testing the equality of data shapes. |
| 5350 | + # re = re // rs |
| 5351 | + re = sympy.ceiling((re + 1) / rs) - 1 |
| 5352 | + strides[i] *= rs |
| 5353 | + rs = 1 |
| 5354 | + other_subset.ranges[i] = (rb, re, rs) |
| 5355 | + |
| 5356 | + tmp, tmparr = self.sdfg.add_view(array, |
| 5357 | + other_subset.size(), |
| 5358 | + arrobj.dtype, |
| 5359 | + storage=arrobj.storage, |
| 5360 | + strides=strides, |
| 5361 | + find_new_name=True) |
| 5362 | + self.views[tmp] = (array, |
| 5363 | + Memlet(data=array, |
| 5364 | + subset=str(expr.subset), |
| 5365 | + other_subset=str(other_subset), |
| 5366 | + volume=expr.accesses, |
| 5367 | + wcr=expr.wcr)) |
| 5368 | + self.variables[tmp] = tmp |
| 5369 | + if not isinstance(tmparr, data.View): |
| 5370 | + rnode = self.current_state.add_read(array, debuginfo=self.current_lineinfo) |
| 5371 | + wnode = self.current_state.add_write(tmp, debuginfo=self.current_lineinfo) |
| 5372 | + # NOTE: We convert the subsets to string because keeping the original symbolic information causes |
| 5373 | + # equality check failures, e.g., in LoopToMap. |
| 5374 | + self.current_state.add_nedge( |
| 5375 | + rnode, wnode, |
| 5376 | + Memlet(data=array, |
| 5377 | + subset=str(expr.subset), |
| 5378 | + other_subset=str(other_subset), |
| 5379 | + volume=expr.accesses, |
| 5380 | + wcr=expr.wcr)) |
| 5381 | + return tmp |
5369 | 5382 |
|
5370 | 5383 | def _parse_subscript_slice(self, |
5371 | 5384 | s: ast.AST, |
|
0 commit comments