Skip to content

Commit d9ec10f

Browse files
Merge pull request #357 from rohany/dedup-error
index_notation,error: deduplicate dimension checking routines
2 parents 859cfd2 + 14bda3d commit d9ec10f

File tree

3 files changed

+23
-67
lines changed

3 files changed

+23
-67
lines changed

src/error/error_checks.cpp

Lines changed: 11 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
#include "error_checks.h"
22

3+
#include <functional>
34
#include <map>
45
#include <set>
5-
#include <stack>
6-
#include <functional>
6+
#include <tuple>
77

88
#include "taco/type.h"
99
#include "taco/index_notation/index_notation.h"
@@ -26,61 +26,24 @@ static vector<const AccessNode*> getAccessNodes(const IndexExpr& expr) {
2626
return readNodes;
2727
}
2828

29-
bool dimensionsTypecheck(const std::vector<IndexVar>& resultVars,
30-
const IndexExpr& expr,
31-
const Shape& shape) {
32-
33-
std::map<IndexVar,Dimension> indexVarDims;
34-
for (size_t mode = 0; mode < resultVars.size(); mode++) {
35-
IndexVar var = resultVars[mode];
36-
auto dimension = shape.getDimension(mode);
37-
if (util::contains(indexVarDims,var) && indexVarDims.at(var) != dimension) {
38-
return false;
39-
}
40-
else {
41-
indexVarDims.insert({var, dimension});
42-
}
43-
}
44-
45-
vector<const AccessNode*> readNodes = getAccessNodes(expr);
46-
for (auto& readNode : readNodes) {
47-
for (size_t mode = 0; mode < readNode->indexVars.size(); mode++) {
48-
IndexVar var = readNode->indexVars[mode];
49-
Dimension dimension =
50-
readNode->tensorVar.getType().getShape().getDimension(mode);
51-
if (util::contains(indexVarDims,var) &&
52-
indexVarDims.at(var) != dimension) {
53-
return false;
54-
}
55-
else {
56-
indexVarDims.insert({var, dimension});
57-
}
58-
}
59-
}
60-
61-
return true;
62-
}
63-
6429
static string addDimensionError(const IndexVar& var,
6530
Dimension dimension1, Dimension dimension2) {
6631
return "Index variable " + util::toString(var) + " is used to index "
6732
"modes of different dimensions (" + util::toString(dimension1) +
6833
" and " + util::toString(dimension2) + ").";
6934
}
7035

71-
std::string dimensionTypecheckErrors(const std::vector<IndexVar>& resultVars,
72-
const IndexExpr& expr,
73-
const Shape& shape) {
36+
std::pair<bool, string> dimensionsTypecheck(const std::vector<IndexVar>& resultVars,
37+
const IndexExpr& expr,
38+
const Shape& shape) {
7439
vector<string> errors;
75-
7640
std::map<IndexVar,Dimension> indexVarDims;
7741
for (size_t mode = 0; mode < resultVars.size(); mode++) {
7842
IndexVar var = resultVars[mode];
7943
auto dimension = shape.getDimension(mode);
8044
if (util::contains(indexVarDims,var) && indexVarDims.at(var) != dimension) {
8145
errors.push_back(addDimensionError(var, indexVarDims.at(var), dimension));
82-
}
83-
else {
46+
} else {
8447
indexVarDims.insert({var, dimension});
8548
}
8649
}
@@ -89,20 +52,16 @@ std::string dimensionTypecheckErrors(const std::vector<IndexVar>& resultVars,
8952
for (auto& readNode : readNodes) {
9053
for (size_t mode = 0; mode < readNode->indexVars.size(); mode++) {
9154
IndexVar var = readNode->indexVars[mode];
92-
Dimension dimension =
93-
readNode->tensorVar.getType().getShape().getDimension(mode);
94-
if (util::contains(indexVarDims,var) &&
95-
indexVarDims.at(var) != dimension) {
96-
errors.push_back(addDimensionError(var, indexVarDims.at(var),
97-
dimension));
98-
}
99-
else {
55+
Dimension dimension = readNode->tensorVar.getType().getShape().getDimension(mode);
56+
if (util::contains(indexVarDims,var) && indexVarDims.at(var) != dimension) {
57+
errors.push_back(addDimensionError(var, indexVarDims.at(var), dimension));
58+
} else {
10059
indexVarDims.insert({var, dimension});
10160
}
10261
}
10362
}
10463

105-
return util::join(errors, " ");
64+
return std::make_pair(errors.empty(), util::join(errors, " "));
10665
}
10766

10867
static void addEdges(vector<IndexVar> indexVars, vector<int> modeOrdering,

src/error/error_checks.h

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include <vector>
55
#include <string>
6+
#include <tuple>
67

78
namespace taco {
89
class IndexVar;
@@ -12,15 +13,12 @@ class Shape;
1213

1314
namespace error {
1415

15-
/// Check that the dimensions indexed by the same variable are the same
16-
bool dimensionsTypecheck(const std::vector<IndexVar>& resultVars,
17-
const IndexExpr& expr,
18-
const Shape& shape);
19-
20-
/// Returns error strings for index variables that don't typecheck
21-
std::string dimensionTypecheckErrors(const std::vector<IndexVar>& resultVars,
22-
const IndexExpr& expr,
23-
const Shape& shape);
16+
/// Check whether all dimensions indexed by the same variable are the same.
17+
/// If they are not, then the first element of the returned tuple will be false,
18+
/// and a human readable error will be returned in the second component.
19+
std::pair<bool, std::string> dimensionsTypecheck(const std::vector<IndexVar>& resultVars,
20+
const IndexExpr& expr,
21+
const Shape& shape);
2422

2523
/// Returns true iff the index expression contains a transposition.
2624
bool containsTranspose(const Format& resultFormat,

src/index_notation/index_notation.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -763,9 +763,8 @@ static void check(Assignment assignment) {
763763
auto freeVars = assignment.getLhs().getIndexVars();
764764
auto indexExpr = assignment.getRhs();
765765
auto shape = tensorVar.getType().getShape();
766-
taco_uassert(error::dimensionsTypecheck(freeVars, indexExpr, shape))
767-
<< error::expr_dimension_mismatch << " "
768-
<< error::dimensionTypecheckErrors(freeVars, indexExpr, shape);
766+
auto typecheck = error::dimensionsTypecheck(freeVars, indexExpr, shape);
767+
taco_uassert(typecheck.first) << error::expr_dimension_mismatch << " " << typecheck.second;
769768
}
770769

771770
Assignment Access::operator=(const IndexExpr& expr) {
@@ -1952,9 +1951,9 @@ static bool isValid(Assignment assignment, string* reason) {
19521951
auto result = lhs.getTensorVar();
19531952
auto freeVars = lhs.getIndexVars();
19541953
auto shape = result.getType().getShape();
1955-
if(!error::dimensionsTypecheck(freeVars, rhs, shape)) {
1956-
*reason = error::expr_dimension_mismatch + " " +
1957-
error::dimensionTypecheckErrors(freeVars, rhs, shape);
1954+
auto typecheck = error::dimensionsTypecheck(freeVars, rhs, shape);
1955+
if (!typecheck.first) {
1956+
*reason = error::expr_dimension_mismatch + " " + typecheck.second;
19581957
return false;
19591958
}
19601959
return true;

0 commit comments

Comments
 (0)