Skip to content

Commit fda53cd

Browse files
Merge pull request #371 from Infinoid/fix-tensor-loading
Add an error message for invalid input tensor names.
2 parents 475fcee + bc0d188 commit fda53cd

File tree

2 files changed

+38
-2
lines changed

2 files changed

+38
-2
lines changed

src/storage/storage.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ struct TensorStorage::Content {
3232
int order = (int)dimensions.size();
3333

3434
taco_iassert(order <= INT_MAX && componentType.getNumBits() <= INT_MAX);
35+
taco_uassert(order == format.getOrder()) <<
36+
"The number of format mode types (" << format.getOrder() << ") " <<
37+
"must match the tensor order (" << dimensions.size() << ").";
3538
vector<int32_t> dimensionsInt32(order);
3639
vector<int32_t> modeOrdering(order);
3740
vector<taco_mode_t> modeTypes(order);

tools/taco.cpp

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -916,8 +916,16 @@ int main(int argc, char* argv[]) {
916916
printCompute = true;
917917
}
918918

919-
// Load tensors
919+
// pre-parse expression, to determine existence and order of loaded tensors
920920
map<string,TensorBase> loadedTensors;
921+
TensorBase temp_tensor;
922+
parser::Parser temp_parser(exprStr, formats, dataTypes, tensorsDimensions, loadedTensors, 42);
923+
try {
924+
temp_parser.parse();
925+
temp_tensor = temp_parser.getResultTensor();
926+
} catch (parser::ParseError& e) {
927+
return reportError(e.getMessage(), 6);
928+
}
921929

922930
// Load tensors
923931
for (auto& tensorNames : inputFilenames) {
@@ -928,7 +936,32 @@ int main(int argc, char* argv[]) {
928936
return reportError("Loaded tensors can only be type double", 7);
929937
}
930938

931-
Format format = util::contains(formats, name) ? formats.at(name) : Dense;
939+
// make sure the tensor exists in the expression (and stash its order)
940+
int found_tensor_order;
941+
bool found = false;
942+
for (auto a : getArgumentAccesses(temp_tensor.getAssignment().concretize())) {
943+
if (a.getTensorVar().getName() == name) {
944+
found_tensor_order = a.getIndexVars().size();
945+
found = true;
946+
break;
947+
}
948+
}
949+
if(found == false) {
950+
return reportError("Cannot load '" + filename + "': no tensor '" + name + "' found in expression", 8);
951+
}
952+
953+
Format format;
954+
if(util::contains(formats, name)) {
955+
// format of this tensor is specified on the command line, use it
956+
format = formats.at(name);
957+
} else {
958+
// create a dense default format of the correct order
959+
std::vector<ModeFormat> modes;
960+
for(int i = 0; i < found_tensor_order; i++) {
961+
modes.push_back(Dense);
962+
}
963+
format = Format({ModeFormatPack(modes)});
964+
}
932965
TensorBase tensor;
933966
TOOL_BENCHMARK_TIMER(tensor = read(filename,format,false),
934967
name+" file read:", timevalue);

0 commit comments

Comments
 (0)