Skip to content

Commit 391eb1d

Browse files
Mike Iovine authored and facebook-github-bot committed
[JIT] UseVariadicOp handles multiple lists (pytorch#66288)
Summary: Pull Request resolved: pytorch#66288 This change makes it so `UseVariadicOp` can transform ops with many Tensor list inputs. Input pattern: ``` %output : Type = op(%list_1, %arg_1, %list_2, %list_3) ``` Output pattern: ``` %output : Type = variadic_op(%list_11, ..., %list_1N, %arg_1, %list_21, ..., %list_2M, %list_31, ..., %list_3K, N, M, K) ``` The length of each list is passed at the end of the variadic op so that the op implementation can process the inputs appropriately. This also frees us from needing to update `hasVarArgs` in static runtime each time we add a variadic op. This diff also makes `UseVariadicOp` more robust. Before, `list_idx` was passed as an argument. Now, `VariadicUpdater` determines `list_idx` from the node's schema. Test Plan: Existing variadic ops do not break: `buck test caffe2/benchmarks/static_runtime:static_runtime_cpptest` Reviewed By: d1jang Differential Revision: D31450811 fbshipit-source-id: 808fcc3ae8940b9e602586f38f8cf9154c9a6462
1 parent c7121ae commit 391eb1d

File tree

3 files changed

+145
-70
lines changed

3 files changed

+145
-70
lines changed

torch/csrc/jit/passes/variadic_ops.cpp

Lines changed: 130 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,46 @@
22

33
#include <torch/csrc/jit/ir/alias_analysis.h>
44
#include <torch/csrc/jit/jit_log.h>
5+
#include <torch/csrc/jit/passes/constant_pooling.h>
56
#include <torch/csrc/jit/passes/remove_mutation.h>
67

78
namespace torch {
89
namespace jit {
910

1011
namespace {
1112

13+
std::vector<size_t> identifyListArgIndices(const c10::FunctionSchema& schema) {
14+
std::vector<size_t> list_indices;
15+
const auto& args = schema.arguments();
16+
for (const auto i : c10::irange(args.size())) {
17+
auto list_type = args[i].type()->castRaw<ListType>();
18+
if (list_type && list_type->getElementType()->castRaw<TensorType>()) {
19+
list_indices.push_back(i);
20+
}
21+
}
22+
return list_indices;
23+
}
24+
25+
bool isTensorListConstruct(Node* node) {
26+
if (node->kind() != prim::ListConstruct) {
27+
return false;
28+
}
29+
const auto type = node->output()->type()->castRaw<ListType>();
30+
TORCH_CHECK(type != nullptr);
31+
const auto& elem_type = type->getElementType();
32+
return elem_type->castRaw<TensorType>();
33+
}
34+
1235
class VariadicUpdater {
1336
public:
14-
explicit VariadicUpdater(
37+
VariadicUpdater(
1538
std::shared_ptr<Graph> graph,
1639
NodeKind op,
17-
NodeKind variadic_op,
18-
size_t list_idx = 0)
40+
NodeKind variadic_op)
1941
: graph_(std::move(graph)),
42+
alias_db_(graph_),
2043
op_(op),
21-
variadic_op_(variadic_op),
22-
list_idx_(list_idx) {}
44+
variadic_op_(variadic_op) {}
2345

2446
bool run() {
2547
collectOpNodes(graph_->block());
@@ -31,90 +53,151 @@ class VariadicUpdater {
3153
}
3254

3355
private:
56+
void recordSchema(Node* op_node) {
57+
const auto& schema = op_node->schema();
58+
auto it = schema_to_list_indices_.find(schema.name());
59+
if (it == schema_to_list_indices_.end()) {
60+
schema_to_list_indices_.emplace(
61+
schema.overload_name(), identifyListArgIndices(schema));
62+
}
63+
}
64+
65+
const std::vector<size_t>& getListIndices(Node* op_node) const {
66+
const auto& schema = op_node->schema();
67+
auto it = schema_to_list_indices_.find(schema.overload_name());
68+
TORCH_CHECK(it != schema_to_list_indices_.end());
69+
return it->second;
70+
}
71+
3472
void collectOpNodes(Block* block) {
3573
for (auto node : block->nodes()) {
3674
if (node->kind() == op_) {
3775
op_nodes_.push_back(node);
76+
recordSchema(node);
3877
}
3978
for (Block* b : node->blocks()) {
4079
collectOpNodes(b);
4180
}
4281
}
4382
}
4483

45-
bool replaceWithVariadicOp(Node* op_node) {
84+
bool allListInputsAreValid(Node* op_node) {
4685
const size_t num_inputs = op_node->inputs().size();
47-
TORCH_CHECK(list_idx_ < num_inputs);
48-
if (op_node->input(list_idx_)->node()->kind() != prim::ListConstruct) {
49-
return false;
86+
for (const auto list_idx : getListIndices(op_node)) {
87+
TORCH_CHECK(list_idx < num_inputs);
88+
const auto list = op_node->input(list_idx)->node();
89+
// We do not transform ops whose list input can not be moved to the
90+
// position before op. This in turn implies that there is some mutation
91+
// of the input list before op.
92+
if (!isTensorListConstruct(list) ||
93+
!alias_db_.couldMoveBeforeTopologically(list, op_node)) {
94+
return false;
95+
}
96+
}
97+
return true;
98+
}
99+
100+
void insertAllInputsBetween(
101+
std::vector<Value*>& inputs,
102+
Node* node,
103+
size_t start_idx,
104+
size_t end_idx) const {
105+
const size_t num_inputs = node->inputs().size();
106+
TORCH_CHECK(start_idx <= end_idx && end_idx <= num_inputs);
107+
inputs.insert(
108+
inputs.end(),
109+
node->inputs().begin() + start_idx,
110+
node->inputs().begin() + end_idx);
111+
}
112+
113+
void insertIntegerInput(std::vector<Value*>& inputs, size_t input) {
114+
auto constant = graph_->create(prim::Constant);
115+
constant->output()->setType(c10::IntType::get());
116+
constant->i_(attr::value, input);
117+
graph_->prependNode(constant);
118+
inputs.push_back(constant->output());
119+
}
120+
121+
void deleteOpNodeAndLists(Node* op_node) {
122+
// Collect the lists before we destroy op_node
123+
std::vector<Node*> lists;
124+
const auto& list_indices = getListIndices(op_node);
125+
lists.reserve(list_indices.size());
126+
for (const size_t list_idx : list_indices) {
127+
auto* list = op_node->input(list_idx)->node();
128+
lists.push_back(list);
50129
}
51-
auto list = op_node->input(list_idx_)->node();
52-
const size_t list_len = list->inputs().size();
53130

54-
// We do not transform ops whose list input can not be moved to the
55-
// position before op. This in turn implies that there is some mutation
56-
// of the input list before op.
57-
if (!getOrCreateAliasDb()->couldMoveBeforeTopologically(list, op_node)) {
131+
GRAPH_UPDATE("Deleting\n", *op_node);
132+
op_node->destroy();
133+
for (auto* list : lists) {
134+
if (!list->hasUses()) {
135+
GRAPH_UPDATE("Deleting\n", *list);
136+
list->destroy();
137+
}
138+
}
139+
}
140+
141+
bool replaceWithVariadicOp(Node* op_node) {
142+
if (!allListInputsAreValid(op_node)) {
58143
return false;
59144
}
60145

61-
// Construct new inputs
62146
std::vector<Value*> inputs;
63-
inputs.reserve(num_inputs + list_len - 1);
64-
inputs.insert(
65-
inputs.end(),
66-
op_node->inputs().begin(),
67-
op_node->inputs().begin() + list_idx_);
68-
inputs.insert(inputs.end(), list->inputs().begin(), list->inputs().end());
69-
inputs.insert(
70-
inputs.end(),
71-
op_node->inputs().begin() + list_idx_ + 1,
72-
op_node->inputs().end());
147+
size_t cur_idx = 0;
148+
std::vector<size_t> list_lens;
149+
for (const size_t list_idx : getListIndices(op_node)) {
150+
insertAllInputsBetween(inputs, op_node, cur_idx, list_idx);
151+
const auto list = op_node->input(list_idx)->node();
152+
const auto list_len = list->inputs().size();
153+
list_lens.push_back(list_len);
154+
insertAllInputsBetween(inputs, list, 0, list_len);
155+
cur_idx = list_idx + 1;
156+
}
157+
insertAllInputsBetween(inputs, op_node, cur_idx, op_node->inputs().size());
158+
159+
// We insert these extra integers at the end of the argument list only if we
160+
// have more than one variadic list (the information is redundant when there
161+
// is only one list because the interpreter knows how many arguments there
162+
// are).
163+
if (list_lens.size() > 1) {
164+
for (const size_t list_len : list_lens) {
165+
insertIntegerInput(inputs, list_len);
166+
}
167+
}
73168

74169
auto var_op_node = op_node->owningGraph()->create(variadic_op_, inputs);
75170
var_op_node->output()->setType(op_node->output()->type());
76171
GRAPH_UPDATE("Adding\n", *var_op_node);
77172
var_op_node->insertBefore(op_node);
78173
GRAPH_UPDATE("Replacing\n", *op_node, "with\n", *var_op_node);
79174
op_node->output()->replaceAllUsesWith(var_op_node->output());
80-
GRAPH_UPDATE("Deleting\n", *op_node);
81-
op_node->destroy();
82-
if (!list->hasUses()) {
83-
GRAPH_UPDATE("Deleting\n", *list);
84-
list->destroy();
85-
}
175+
deleteOpNodeAndLists(op_node);
86176
return true;
87177
}
88178

89-
AliasDb* getOrCreateAliasDb() {
90-
if (!aliasDb_) {
91-
aliasDb_ = std::make_unique<AliasDb>(graph_);
92-
}
93-
return aliasDb_.get();
94-
}
95-
96179
std::shared_ptr<Graph> graph_;
97-
std::unique_ptr<AliasDb> aliasDb_ = nullptr;
98-
99180
std::vector<Node*> op_nodes_;
100181

182+
AliasDb alias_db_;
183+
101184
NodeKind op_;
102185
NodeKind variadic_op_;
103186

104-
size_t list_idx_;
187+
std::unordered_map<std::string, std::vector<size_t>> schema_to_list_indices_;
105188
};
106189

107190
} // namespace
108191

109192
bool UseVariadicOp(
110193
const std::shared_ptr<Graph>& graph,
111194
NodeKind op,
112-
NodeKind variadic_op,
113-
size_t list_idx) {
195+
NodeKind variadic_op) {
114196
const std::string pass_name = std::string("variadic ") + op.toQualString();
115197
GRAPH_DUMP("Before " + pass_name, graph);
116-
bool changed = VariadicUpdater(graph, op, variadic_op, list_idx).run();
198+
bool changed = VariadicUpdater(graph, op, variadic_op).run();
117199
if (changed) {
200+
ConstantPooling(graph);
118201
GRAPH_DUMP("After " + pass_name, graph);
119202
}
120203
return changed;
@@ -123,14 +206,13 @@ bool UseVariadicOp(
123206
bool RemoveListMutationAndUseVariadicOp(
124207
const std::shared_ptr<Graph>& graph,
125208
NodeKind op,
126-
NodeKind variadic_op,
127-
size_t list_idx) {
209+
NodeKind variadic_op) {
128210
bool changed_in_last_iter = true;
129211
bool changed = false;
130212
while (changed_in_last_iter) {
131213
changed_in_last_iter = RemoveListMutation(graph);
132214
changed_in_last_iter =
133-
UseVariadicOp(graph, op, variadic_op, list_idx) || changed_in_last_iter;
215+
UseVariadicOp(graph, op, variadic_op) || changed_in_last_iter;
134216
changed = changed || changed_in_last_iter;
135217
}
136218
return changed;
@@ -139,15 +221,13 @@ bool RemoveListMutationAndUseVariadicOp(
139221
bool UseVariadicCat(const std::shared_ptr<Graph>& graph) {
140222
return UseVariadicOp(graph, aten::cat, prim::VarConcat);
141223
}
142-
143224
bool RemoveListMutationAndUseVariadicCat(const std::shared_ptr<Graph>& graph) {
144225
return RemoveListMutationAndUseVariadicOp(graph, aten::cat, prim::VarConcat);
145226
}
146227

147228
bool UseVariadicStack(const std::shared_ptr<Graph>& graph) {
148229
return UseVariadicOp(graph, aten::stack, prim::VarStack);
149230
}
150-
151231
bool RemoveListMutationAndUseVariadicStack(
152232
const std::shared_ptr<Graph>& graph) {
153233
return RemoveListMutationAndUseVariadicOp(graph, aten::stack, prim::VarStack);

torch/csrc/jit/passes/variadic_ops.h

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,31 +5,27 @@
55
namespace torch {
66
namespace jit {
77

8-
// Replaces the `aten::cat` ops in the given graph with variadic cat ops.
9-
// Returns true if the graph is modified.
10-
TORCH_API bool UseVariadicCat(const std::shared_ptr<Graph>& graph);
11-
12-
TORCH_API bool RemoveListMutationAndUseVariadicCat(
13-
const std::shared_ptr<Graph>& graph);
14-
15-
// Replaces the `aten::stack` ops in the given graph with variadic cat ops.
16-
// Returns true if the graph is modified.
17-
TORCH_API bool UseVariadicStack(const std::shared_ptr<Graph>& graph);
18-
19-
TORCH_API bool RemoveListMutationAndUseVariadicStack(
20-
const std::shared_ptr<Graph>& graph);
21-
8+
// Try to replace an op that takes a list input with another op that takes a
9+
// variadic number of arguments.
2210
TORCH_API bool UseVariadicOp(
2311
const std::shared_ptr<Graph>& graph,
2412
NodeKind op,
25-
NodeKind variadic_op,
26-
size_t list_idx = 0);
13+
NodeKind variadic_op);
2714

2815
TORCH_API bool RemoveListMutationAndUseVariadicOp(
2916
const std::shared_ptr<Graph>& graph,
3017
NodeKind op,
31-
NodeKind variadic_op,
32-
size_t list_idx = 0);
18+
NodeKind variadic_op);
19+
20+
// Convenient functions for replacing aten::stack/aten::cat with their
21+
// variadic versions.
22+
TORCH_API bool UseVariadicCat(const std::shared_ptr<Graph>& graph);
23+
TORCH_API bool RemoveListMutationAndUseVariadicCat(
24+
const std::shared_ptr<Graph>& graph);
25+
26+
TORCH_API bool UseVariadicStack(const std::shared_ptr<Graph>& graph);
27+
TORCH_API bool RemoveListMutationAndUseVariadicStack(
28+
const std::shared_ptr<Graph>& graph);
3329

3430
} // namespace jit
3531
} // namespace torch

torch/csrc/jit/runtime/static/impl.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,7 @@ void OptimizeGraph(
9595
graph,
9696
c10::Symbol::fromQualString("fb::sigrid_transforms_torch_bind"),
9797
c10::Symbol::fromQualString(
98-
"fb::variadic_sigrid_transforms_torch_bind"),
99-
1 /* list_idx */);
98+
"fb::variadic_sigrid_transforms_torch_bind"));
10099
FuseSignLog1P(graph);
101100

102101
// TODO: we can avoid this guard by moving operations

0 commit comments

Comments (0)