@@ -72,6 +72,54 @@ TEST(scheduling, splitIndexStmt) {
7272 ASSERT_TRUE (equals (a (i) = b (i), i2Forall.getStmt ()));
7373}
7474
75+ TEST (scheduling, fuseDenseLoops) {
76+ auto dim = 4 ;
77+ Tensor<int > A (" A" , {dim, dim, dim}, {Dense, Dense, Dense});
78+ Tensor<int > B (" B" , {dim, dim, dim}, {Dense, Dense, Dense});
79+ Tensor<int > expected (" expected" , {dim, dim, dim}, {Dense, Dense, Dense});
80+ IndexVar f (" f" ), g (" g" );
81+ for (int i = 0 ; i < dim; i++) {
82+ for (int j = 0 ; j < dim; j++) {
83+ for (int k = 0 ; k < dim; k++) {
84+ A.insert ({i, j, k}, i + j + k);
85+ B.insert ({i, j, k}, i + j + k);
86+ expected.insert ({i, j, k}, 2 * (i + j + k));
87+ }
88+ }
89+ }
90+ A.pack ();
91+ B.pack ();
92+ expected.pack ();
93+
94+ // Helper function to evaluate the target statement and verify the results.
95+ // It takes in a function that applies some scheduling transforms to the
96+ // input IndexStmt, and applies to the point-wise tensor addition below.
97+ // The test is structured this way as TACO does its best to avoid re-compilation
98+ // whenever possible. I.e. changing the stmt that a tensor is compiled with
99+ // doesn't cause compilation to occur again.
100+ auto testFn = [&](std::function<IndexStmt (IndexStmt)> modifier) {
101+ Tensor<int > C (" C" , {dim, dim, dim}, {Dense, Dense, Dense});
102+ C (i, j, k) = A (i, j, k) + B (i, j, k);
103+ auto stmt = C.getAssignment ().concretize ();
104+ C.compile (modifier (stmt));
105+ C.evaluate ();
106+ ASSERT_TRUE (equals (C, expected)) << endl << C << endl << expected << endl;
107+ };
108+
109+ // First, a sanity check with no transformations.
110+ testFn ([](IndexStmt stmt) { return stmt; });
111+ // Next, fuse the outer two loops. This tests the original bug in #355.
112+ testFn ([&](IndexStmt stmt) {
113+ return stmt.fuse (i, j, f);
114+ });
115+ // Lastly, fuse all of the loops into a single loop. This ensures that
116+ // locators with a chain of ancestors have all of their dependencies
117+ // generated in a valid ordering.
118+ testFn ([&](IndexStmt stmt) {
119+ return stmt.fuse (i, j, f).fuse (f, k, g);
120+ });
121+ }
122+
75123TEST (scheduling, lowerDenseMatrixMul) {
76124 Tensor<double > A (" A" , {4 , 4 }, {Dense, Dense});
77125 Tensor<double > B (" B" , {4 , 4 }, {Dense, Dense});
0 commit comments