11#include " error_checks.h"
22
3+ #include < functional>
34#include < map>
45#include < set>
5- #include < stack>
6- #include < functional>
6+ #include < tuple>
77
88#include " taco/type.h"
99#include " taco/index_notation/index_notation.h"
@@ -26,61 +26,24 @@ static vector<const AccessNode*> getAccessNodes(const IndexExpr& expr) {
2626 return readNodes;
2727}
2828
29- bool dimensionsTypecheck (const std::vector<IndexVar>& resultVars,
30- const IndexExpr& expr,
31- const Shape& shape) {
32-
33- std::map<IndexVar,Dimension> indexVarDims;
34- for (size_t mode = 0 ; mode < resultVars.size (); mode++) {
35- IndexVar var = resultVars[mode];
36- auto dimension = shape.getDimension (mode);
37- if (util::contains (indexVarDims,var) && indexVarDims.at (var) != dimension) {
38- return false ;
39- }
40- else {
41- indexVarDims.insert ({var, dimension});
42- }
43- }
44-
45- vector<const AccessNode*> readNodes = getAccessNodes (expr);
46- for (auto & readNode : readNodes) {
47- for (size_t mode = 0 ; mode < readNode->indexVars .size (); mode++) {
48- IndexVar var = readNode->indexVars [mode];
49- Dimension dimension =
50- readNode->tensorVar .getType ().getShape ().getDimension (mode);
51- if (util::contains (indexVarDims,var) &&
52- indexVarDims.at (var) != dimension) {
53- return false ;
54- }
55- else {
56- indexVarDims.insert ({var, dimension});
57- }
58- }
59- }
60-
61- return true ;
62- }
63-
6429static string addDimensionError (const IndexVar& var,
6530 Dimension dimension1, Dimension dimension2) {
6631 return " Index variable " + util::toString (var) + " is used to index "
6732 " modes of different dimensions (" + util::toString (dimension1) +
6833 " and " + util::toString (dimension2) + " )." ;
6934}
7035
71- std::string dimensionTypecheckErrors (const std::vector<IndexVar>& resultVars,
72- const IndexExpr& expr,
73- const Shape& shape) {
36+ std::pair< bool , string> dimensionsTypecheck (const std::vector<IndexVar>& resultVars,
37+ const IndexExpr& expr,
38+ const Shape& shape) {
7439 vector<string> errors;
75-
7640 std::map<IndexVar,Dimension> indexVarDims;
7741 for (size_t mode = 0 ; mode < resultVars.size (); mode++) {
7842 IndexVar var = resultVars[mode];
7943 auto dimension = shape.getDimension (mode);
8044 if (util::contains (indexVarDims,var) && indexVarDims.at (var) != dimension) {
8145 errors.push_back (addDimensionError (var, indexVarDims.at (var), dimension));
82- }
83- else {
46+ } else {
8447 indexVarDims.insert ({var, dimension});
8548 }
8649 }
@@ -89,20 +52,16 @@ std::string dimensionTypecheckErrors(const std::vector<IndexVar>& resultVars,
8952 for (auto & readNode : readNodes) {
9053 for (size_t mode = 0 ; mode < readNode->indexVars .size (); mode++) {
9154 IndexVar var = readNode->indexVars [mode];
92- Dimension dimension =
93- readNode->tensorVar .getType ().getShape ().getDimension (mode);
94- if (util::contains (indexVarDims,var) &&
95- indexVarDims.at (var) != dimension) {
96- errors.push_back (addDimensionError (var, indexVarDims.at (var),
97- dimension));
98- }
99- else {
55+ Dimension dimension = readNode->tensorVar .getType ().getShape ().getDimension (mode);
56+ if (util::contains (indexVarDims,var) && indexVarDims.at (var) != dimension) {
57+ errors.push_back (addDimensionError (var, indexVarDims.at (var), dimension));
58+ } else {
10059 indexVarDims.insert ({var, dimension});
10160 }
10261 }
10362 }
10463
105- return util::join (errors, " " );
64+ return std::make_pair (errors. empty (), util::join (errors, " " ) );
10665}
10766
10867static void addEdges (vector<IndexVar> indexVars, vector<int > modeOrdering,
0 commit comments