Skip to content

Commit 9cbbbbc

Browse files
authored
[Layout] Support layout forward with multi dimension (#867)
* Enhance LayoutNode::Forward method to handle variable transformations more robustly - Updated the method to check for a minimum number of input dimensions. - Introduced a mechanism to transform the last InputDim() elements of the input variables. - Concatenated transformed variables with the remaining input variables for a comprehensive output. * Refactor LayoutNode::Forward method for improved readability - Removed unnecessary whitespace to enhance code clarity. - Maintained existing functionality while streamlining the transformation process of input variables.
1 parent 86aaf3c commit 9cbbbbc

File tree

1 file changed

+22
-3
lines changed

1 file changed

+22
-3
lines changed

src/layout/layout.cc

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,13 +115,32 @@ Array<PrimExpr> LayoutNode::OutputShape() const {
115115
Array<PrimExpr> LayoutNode::Forward(const Array<PrimExpr> &vars) const {
116116
if (vars.empty())
117117
return forward_index_;
118-
ICHECK_EQ(vars.size(), InputDim());
118+
ICHECK_GE(vars.size(), InputDim());
119+
120+
// Take the last InputDim() elements for transformation
121+
Array<PrimExpr> transform_vars;
122+
for (size_t i = vars.size() - InputDim(); i < vars.size(); i++) {
123+
transform_vars.push_back(vars[i]);
124+
}
125+
119126
Map<Var, PrimExpr> vmap;
120127
for (size_t i = 0; i < InputDim(); i++) {
121-
vmap.Set(InputPlaceholder(i), vars[i]);
128+
vmap.Set(InputPlaceholder(i), transform_vars[i]);
122129
}
123-
return forward_index_.Map(
130+
131+
Array<PrimExpr> transformed = forward_index_.Map(
124132
[&](const PrimExpr &e) { return Substitute(e, vmap); });
133+
134+
// Concatenate with the remaining elements from vars
135+
Array<PrimExpr> result;
136+
for (size_t i = 0; i < vars.size() - InputDim(); i++) {
137+
result.push_back(vars[i]);
138+
}
139+
for (const auto &expr : transformed) {
140+
result.push_back(expr);
141+
}
142+
143+
return result;
125144
}
126145

127146
Fragment FragmentNode::Repeat(const Array<PrimExpr> &repeats,

0 commit comments

Comments
 (0)