@@ -916,8 +916,16 @@ int main(int argc, char* argv[]) {
916916 printCompute = true ;
917917 }
918918
919- // Load tensors
919+ // pre-parse expression, to determine existence and order of loaded tensors
920920 map<string,TensorBase> loadedTensors;
921+ TensorBase temp_tensor;
922+ parser::Parser temp_parser (exprStr, formats, dataTypes, tensorsDimensions, loadedTensors, 42 );
923+ try {
924+ temp_parser.parse ();
925+ temp_tensor = temp_parser.getResultTensor ();
926+ } catch (parser::ParseError& e) {
927+ return reportError (e.getMessage (), 6 );
928+ }
921929
922930 // Load tensors
923931 for (auto & tensorNames : inputFilenames) {
@@ -928,7 +936,32 @@ int main(int argc, char* argv[]) {
928936 return reportError (" Loaded tensors can only be type double" , 7 );
929937 }
930938
931- Format format = util::contains (formats, name) ? formats.at (name) : Dense;
939+ // make sure the tensor exists in the expression (and stash its order)
940+ int found_tensor_order;
941+ bool found = false ;
942+ for (auto a : getArgumentAccesses (temp_tensor.getAssignment ().concretize ())) {
943+ if (a.getTensorVar ().getName () == name) {
944+ found_tensor_order = a.getIndexVars ().size ();
945+ found = true ;
946+ break ;
947+ }
948+ }
949+ if (found == false ) {
950+ return reportError (" Cannot load '" + filename + " ': no tensor '" + name + " ' found in expression" , 8 );
951+ }
952+
953+ Format format;
954+ if (util::contains (formats, name)) {
955+ // format of this tensor is specified on the command line, use it
956+ format = formats.at (name);
957+ } else {
958+ // create a dense default format of the correct order
959+ std::vector<ModeFormat> modes;
960+ for (int i = 0 ; i < found_tensor_order; i++) {
961+ modes.push_back (Dense);
962+ }
963+ format = Format ({ModeFormatPack (modes)});
964+ }
932965 TensorBase tensor;
933966 TOOL_BENCHMARK_TIMER (tensor = read (filename,format,false ),
934967 name+" file read:" , timevalue);
0 commit comments