Skip to content

Commit 1a88b3c

Browse files
romanctbennun
andauthored
Avoid usage of deprecated Indices class (#2292)
After a brief discussion `subsets.Indices` were deprecated last week with PR #2282. Since then, many Dace tests emit warnings because of remaining usage of `Indices` in the offset member functions of `subsets.Range`. This PR suggests to adapt `Range.from_indices()` to add support for a sequence of (symbolic) numbers or strings (as suggested in Mattermost). This allows to remove the remaining usage of `subsets.Indices` constructors in the DaCe codebase, which gets rid of a bunch of warnings emitted in test or upstream/user code. Only hickup that I had doing this was the function `_add_read_slice()` , called from `visit_Subscript()` of the `ProgramVisitor` in `newast.py` . That function would check for subsets to be either ranges or indices. And if subsets were indices, we'd go another way. That code path separation is apparently loosely tied to some other place in the codebase because we'd get errors if we were going the sub-optimal ranges-path with indices. I do now check if ranges are indices and set the flag accordingly. That seems to fix issues in tests. I've also checked (manually) all other cases where we'd go a different code path in case subsets are indices. There are some and the remaining ones all "upgrade" indices to ranges. They can be removed once we remove the deprecated `Indices` class. --------- Co-authored-by: Roman Cattaneo <> Co-authored-by: Tal Ben-Nun <tbennun@gmail.com>
1 parent 2ca64a4 commit 1a88b3c

29 files changed

+132
-160
lines changed

dace/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .version import __version__
44
from . import attr_enum
55
from .dtypes import *
6+
from . import serialize
67

78
# Import built-in hooks
89
from .builtin_hooks import *

dace/codegen/compiler.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
compiles each target separately, links all targets to one binary, and
44
returns the corresponding CompiledSDFG object. """
55

6-
from __future__ import print_function
7-
86
import collections
97
import os
108
import six

dace/codegen/cppunparse.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@
6969
##########################################################################
7070
### END OF astunparse LICENSES
7171

72-
from __future__ import print_function, unicode_literals
7372
from functools import lru_cache
7473
import inspect
7574
import six

dace/codegen/targets/cpp.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def copy_expr(
8181
if offset is None:
8282
s = None
8383
elif not isinstance(offset, subsets.Subset):
84-
s = subsets.Indices(offset)
84+
s = subsets.Range.from_indices(offset)
8585
else:
8686
s = offset
8787
o = None
@@ -529,7 +529,7 @@ def cpp_array_expr(sdfg,
529529
framecode: Optional['DaCeCodeGenerator'] = None):
530530
""" Converts an Indices/Range object to a C++ array access string. """
531531
subset = memlet.subset if not use_other_subset else memlet.other_subset
532-
s = subset if relative_offset else subsets.Indices(offset)
532+
s = subset if relative_offset else subsets.Range.from_indices(offset)
533533
o = offset if relative_offset else None
534534
desc = (sdfg.arrays[memlet.data] if referenced_array is None else referenced_array)
535535
offset_cppstr = cpp_offset_expr(desc, s, o, packed_veclen, indices=indices)
@@ -578,7 +578,7 @@ def cpp_ptr_expr(sdfg,
578578
codegen: 'TargetCodeGenerator' = None):
579579
""" Converts a memlet to a C++ pointer expression. """
580580
subset = memlet.subset if not use_other_subset else memlet.other_subset
581-
s = subset if relative_offset else subsets.Indices(offset)
581+
s = subset if relative_offset else subsets.Range.from_indices(offset)
582582
o = offset if relative_offset else None
583583
desc = sdfg.arrays[memlet.data]
584584
if isinstance(indices, str):

dace/frontend/operations.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
2-
from __future__ import print_function
32
from functools import partial
43
from itertools import chain, repeat
54

dace/frontend/python/memlet_parser.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,14 @@ def pyexpr_to_symbolic(defined_arrays_and_symbols: Dict[str, Any], expr_ast: ast
6262
def _ndslice_to_subset(ndslice):
6363
is_tuple = [isinstance(x, tuple) for x in ndslice]
6464
if not any(is_tuple):
65-
return subsets.Indices(ndslice)
66-
else:
67-
if not all(is_tuple):
68-
# If a mix of ranges and indices is found, convert to range
69-
for i in range(len(ndslice)):
70-
if not is_tuple[i]:
71-
ndslice[i] = (ndslice[i], ndslice[i], 1)
72-
return subsets.Range(ndslice)
65+
return subsets.Range.from_indices(ndslice)
66+
67+
if not all(is_tuple):
68+
# If a mix of ranges and indices is found, convert to range
69+
for i in range(len(ndslice)):
70+
if not is_tuple[i]:
71+
ndslice[i] = (ndslice[i], ndslice[i], 1)
72+
return subsets.Range(ndslice)
7373

7474

7575
def _parse_dim_atom(das, atom):

dace/frontend/python/newast.py

Lines changed: 93 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from collections import OrderedDict
44
import copy
55
import itertools
6-
import networkx as nx
76
import re
87
import sys
98
import time
@@ -15,7 +14,6 @@
1514

1615
import dace
1716
from dace import data, dtypes, subsets, symbolic, sdfg as sd
18-
from dace import sourcemap
1917
from dace.config import Config
2018
from dace.frontend.common import op_repository as oprepo
2119
from dace.frontend.python import astutils
@@ -475,7 +473,7 @@ def add_indirection_subgraph(sdfg: SDFG,
475473
arr = sdfg.arrays[arrname]
476474
subset = subsets.Range.from_array(arr)
477475
else:
478-
subset = subsets.Indices(access)
476+
subset = subsets.Range.from_indices(access)
479477
# Memlet to load the indirection index
480478
indexMemlet = Memlet.simple(arrname, subset)
481479
input_index_memlets.append(indexMemlet)
@@ -1771,8 +1769,8 @@ def _inject_consume_memlets(self, dec, entry, inputs, internal_node, sdfg: SDFG,
17711769
# Inject element to internal SDFG arrays
17721770
ntrans = f'consume_{stream_name}'
17731771
ntrans, _ = sdfg.add_array(ntrans, [1], self.sdfg.arrays[stream_name].dtype, find_new_name=True)
1774-
internal_memlet = dace.Memlet.simple(ntrans, subsets.Indices([0]))
1775-
external_memlet = dace.Memlet.simple(stream_name, subsets.Indices([0]), num_accesses=-1)
1772+
internal_memlet = dace.Memlet.simple(ntrans, subsets.Range.from_indices([0]))
1773+
external_memlet = dace.Memlet.simple(stream_name, subsets.Range.from_indices([0]), num_accesses=-1)
17761774

17771775
# Inject to internal tasklet
17781776
if not dec.endswith('scope'):
@@ -5290,82 +5288,97 @@ def _add_read_slice(self, array: str, node: ast.Subscript, expr: MemletExpr):
52905288
# Make copy slicing state
52915289
rnode = self.current_state.add_read(array, debuginfo=self.current_lineinfo)
52925290
return self._array_indirection_subgraph(rnode, expr)
5291+
5292+
is_index = False
5293+
if isinstance(expr.subset, subsets.Indices):
5294+
is_index = True
5295+
other_subset = subsets.Range([(i, i, 1) for i in expr.subset])
52935296
else:
5294-
is_index = False
5295-
if isinstance(expr.subset, subsets.Indices):
5296-
is_index = True
5297-
other_subset = subsets.Range([(i, i, 1) for i in expr.subset])
5298-
else:
5299-
other_subset = copy.deepcopy(expr.subset)
5300-
strides = list(arrobj.strides)
5301-
5302-
# Make new axes and squeeze for scalar subsets (as per numpy behavior)
5303-
# For example: A[0, np.newaxis, 5:7] results in a 1x2 ndarray
5304-
new_axes = []
5305-
if expr.new_axes:
5306-
new_axes = other_subset.unsqueeze(expr.new_axes)
5307-
for i in new_axes:
5308-
strides.insert(i, 1)
5309-
length = len(other_subset)
5310-
nsqz = other_subset.squeeze(ignore_indices=new_axes)
5311-
sqz = [i for i in range(length) if i not in nsqz]
5312-
for i in reversed(sqz):
5313-
strides.pop(i)
5314-
if not strides:
5315-
strides = None
5316-
5317-
if is_index:
5318-
tmp = self.get_target_name(default=f'{array}_index')
5319-
tmp, tmparr = self.sdfg.add_scalar(tmp,
5320-
arrobj.dtype,
5321-
arrobj.storage,
5322-
transient=True,
5323-
find_new_name=True)
5324-
else:
5325-
for i in range(len(other_subset.ranges)):
5326-
rb, re, rs = other_subset.ranges[i]
5327-
if (rs < 0) == True:
5328-
raise DaceSyntaxError(
5329-
self, node, 'Negative strides are not supported in subscripts. '
5330-
'Please use a Map scope to express this operation.')
5331-
re = re - rb
5332-
rb = 0
5333-
if rs != 1:
5334-
# NOTE: We use the identity floor(A/B) = ceiling((A + 1) / B) - 1
5335-
# because Range.size() uses the ceiling method and that way we avoid
5336-
# false negatives when testing the equality of data shapes.
5337-
# re = re // rs
5338-
re = sympy.ceiling((re + 1) / rs) - 1
5339-
strides[i] *= rs
5340-
rs = 1
5341-
other_subset.ranges[i] = (rb, re, rs)
5342-
5343-
tmp, tmparr = self.sdfg.add_view(array,
5344-
other_subset.size(),
5345-
arrobj.dtype,
5346-
storage=arrobj.storage,
5347-
strides=strides,
5348-
find_new_name=True)
5349-
self.views[tmp] = (array,
5350-
Memlet(data=array,
5351-
subset=str(expr.subset),
5352-
other_subset=str(other_subset),
5353-
volume=expr.accesses,
5354-
wcr=expr.wcr))
5355-
self.variables[tmp] = tmp
5356-
if not isinstance(tmparr, data.View):
5357-
rnode = self.current_state.add_read(array, debuginfo=self.current_lineinfo)
5358-
wnode = self.current_state.add_write(tmp, debuginfo=self.current_lineinfo)
5359-
# NOTE: We convert the subsets to string because keeping the original symbolic information causes
5360-
# equality check failures, e.g., in LoopToMap.
5361-
self.current_state.add_nedge(
5362-
rnode, wnode,
5363-
Memlet(data=array,
5364-
subset=str(expr.subset),
5365-
other_subset=str(other_subset),
5366-
volume=expr.accesses,
5367-
wcr=expr.wcr))
5368-
return tmp
5297+
5298+
def range_is_index(range: subsets.Range) -> bool:
5299+
"""
5300+
Check if the given subset range is an index.
5301+
5302+
Conditions for an index are as follows:
5303+
- tile_size of each range has to be 1
5304+
- the range increment has to be 1
5305+
- start/stop element of the range have to be equal
5306+
"""
5307+
for r, t in zip(range.ranges, range.tile_sizes):
5308+
if t != 1 or r[2] != 1 or r[0] != r[1]:
5309+
return False
5310+
return True
5311+
5312+
# We also check the type of the slice attribute of the node
5313+
# in order to distinguish between A[0] and A[0:1], which are semantically different in numpy
5314+
# (the former is an index, the latter is a slice).
5315+
is_index = range_is_index(expr.subset) and not isinstance(node.slice, ast.Slice)
5316+
other_subset = copy.deepcopy(expr.subset)
5317+
strides = list(arrobj.strides)
5318+
5319+
# Make new axes and squeeze for scalar subsets (as per numpy behavior)
5320+
# For example: A[0, np.newaxis, 5:7] results in a 1x2 ndarray
5321+
new_axes = []
5322+
if expr.new_axes:
5323+
new_axes = other_subset.unsqueeze(expr.new_axes)
5324+
for i in new_axes:
5325+
strides.insert(i, 1)
5326+
length = len(other_subset)
5327+
nsqz = other_subset.squeeze(ignore_indices=new_axes)
5328+
sqz = [i for i in range(length) if i not in nsqz]
5329+
for i in reversed(sqz):
5330+
strides.pop(i)
5331+
if not strides:
5332+
strides = None
5333+
5334+
if is_index:
5335+
tmp = self.get_target_name(default=f'{array}_index')
5336+
tmp, tmparr = self.sdfg.add_scalar(tmp, arrobj.dtype, arrobj.storage, transient=True, find_new_name=True)
5337+
else:
5338+
for i in range(len(other_subset.ranges)):
5339+
rb, re, rs = other_subset.ranges[i]
5340+
if (rs < 0) == True:
5341+
raise DaceSyntaxError(
5342+
self, node, 'Negative strides are not supported in subscripts. '
5343+
'Please use a Map scope to express this operation.')
5344+
re = re - rb
5345+
rb = 0
5346+
if rs != 1:
5347+
# NOTE: We use the identity floor(A/B) = ceiling((A + 1) / B) - 1
5348+
# because Range.size() uses the ceiling method and that way we avoid
5349+
# false negatives when testing the equality of data shapes.
5350+
# re = re // rs
5351+
re = sympy.ceiling((re + 1) / rs) - 1
5352+
strides[i] *= rs
5353+
rs = 1
5354+
other_subset.ranges[i] = (rb, re, rs)
5355+
5356+
tmp, tmparr = self.sdfg.add_view(array,
5357+
other_subset.size(),
5358+
arrobj.dtype,
5359+
storage=arrobj.storage,
5360+
strides=strides,
5361+
find_new_name=True)
5362+
self.views[tmp] = (array,
5363+
Memlet(data=array,
5364+
subset=str(expr.subset),
5365+
other_subset=str(other_subset),
5366+
volume=expr.accesses,
5367+
wcr=expr.wcr))
5368+
self.variables[tmp] = tmp
5369+
if not isinstance(tmparr, data.View):
5370+
rnode = self.current_state.add_read(array, debuginfo=self.current_lineinfo)
5371+
wnode = self.current_state.add_write(tmp, debuginfo=self.current_lineinfo)
5372+
# NOTE: We convert the subsets to string because keeping the original symbolic information causes
5373+
# equality check failures, e.g., in LoopToMap.
5374+
self.current_state.add_nedge(
5375+
rnode, wnode,
5376+
Memlet(data=array,
5377+
subset=str(expr.subset),
5378+
other_subset=str(other_subset),
5379+
volume=expr.accesses,
5380+
wcr=expr.wcr))
5381+
return tmp
53695382

53705383
def _parse_subscript_slice(self,
53715384
s: ast.AST,

dace/frontend/python/wrappers.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
22
""" Types and wrappers used in DaCe's Python frontend. """
3-
from __future__ import print_function
43
import numpy
54
import itertools
65
from collections import deque

dace/libraries/mpi/nodes/redistribute.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ def expansion(node, parent_state, parent_sdfg):
2424

2525
inp_symbols = [symbolic.symbol(f"__inp_s{i}") for i in range(len(inp_buffer.shape))]
2626
out_symbols = [symbolic.symbol(f"__out_s{i}") for i in range(len(out_buffer.shape))]
27-
inp_subset = subsets.Indices(inp_symbols)
28-
out_subset = subsets.Indices(out_symbols)
27+
inp_subset = subsets.Range.from_indices(inp_symbols)
28+
out_subset = subsets.Range.from_indices(out_symbols)
2929
inp_offset = cpp.cpp_offset_expr(inp_buffer, inp_subset)
3030
out_offset = cpp.cpp_offset_expr(out_buffer, out_subset)
3131
print(inp_offset)

dace/subsets.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved.
22
import dace.serialize
3-
from dace import data, symbolic, dtypes
4-
import re
3+
from dace import symbolic
54
import sympy as sp
65
from functools import reduce
7-
import sympy.core.sympify
86
from typing import List, Optional, Sequence, Set, Union
97
import warnings
108
from dace.config import Config
@@ -322,8 +320,12 @@ def __init__(self, ranges):
322320
self.tile_sizes = parsed_tiles
323321

324322
@staticmethod
325-
def from_indices(indices: 'Indices'):
326-
return Range([(i, i, 1) for i in indices.indices])
323+
def from_indices(indices: Union["Indices", Sequence[int | str | symbolic.SymbolicType]]):
324+
if isinstance(indices, Indices):
325+
return Range([(i, i, 1) for i in indices.indices])
326+
327+
indices = [symbolic.pystr_to_symbolic(i) for i in indices]
328+
return Range([(i, i, 1) for i in indices])
327329

328330
def to_json(self):
329331
ret = []
@@ -498,9 +500,9 @@ def offset(self, other, negative, indices=None, offset_end=True):
498500
return
499501
if not isinstance(other, Subset):
500502
if isinstance(other, (list, tuple)):
501-
other = Indices(other)
503+
other = Range.from_indices(other)
502504
else:
503-
other = Indices([other for _ in self.ranges])
505+
other = Range.from_indices([other for _ in self.ranges])
504506
mult = -1 if negative else 1
505507
if indices is None:
506508
indices = set(range(len(self.ranges)))
@@ -516,9 +518,9 @@ def offset_new(self, other, negative, indices=None, offset_end=True):
516518
return Range(self.ranges)
517519
if not isinstance(other, Subset):
518520
if isinstance(other, (list, tuple)):
519-
other = Indices(other)
521+
other = Range.from_indices(other)
520522
else:
521-
other = Indices([other for _ in self.ranges])
523+
other = Range.from_indices([other for _ in self.ranges])
522524
mult = -1 if negative else 1
523525
if indices is None:
524526
indices = set(range(len(self.ranges)))
@@ -716,7 +718,7 @@ def from_string(string):
716718
tsize = tokens[3]
717719
else:
718720
tsize = 1
719-
except sympy.SympifyError:
721+
except sp.SympifyError:
720722
raise SyntaxError("Invalid range: {}".format(string))
721723
# Append range
722724
ranges.append((begin, end, step, tsize))
@@ -1148,9 +1150,9 @@ def bounding_box_union(subset_a: Subset, subset_b: Subset) -> Range:
11481150
elif len(brb.free_symbols) == 0:
11491151
minrb = brb
11501152
else:
1151-
minrb = sympy.Min(arb, brb)
1153+
minrb = sp.Min(arb, brb)
11521154
else:
1153-
minrb = sympy.Min(arb, brb)
1155+
minrb = sp.Min(arb, brb)
11541156

11551157
try:
11561158
maxre = max(are, bre)
@@ -1161,9 +1163,9 @@ def bounding_box_union(subset_a: Subset, subset_b: Subset) -> Range:
11611163
elif len(bre.free_symbols) == 0:
11621164
maxre = are
11631165
else:
1164-
maxre = sympy.Max(are, bre)
1166+
maxre = sp.Max(are, bre)
11651167
else:
1166-
maxre = sympy.Max(are, bre)
1168+
maxre = sp.Max(are, bre)
11671169

11681170
result.append((minrb, maxre, 1))
11691171

0 commit comments

Comments
 (0)