22
33#include < torch/csrc/jit/ir/alias_analysis.h>
44#include < torch/csrc/jit/jit_log.h>
5+ #include < torch/csrc/jit/passes/constant_pooling.h>
56#include < torch/csrc/jit/passes/remove_mutation.h>
67
78namespace torch {
89namespace jit {
910
1011namespace {
1112
13+ std::vector<size_t > identifyListArgIndices (const c10::FunctionSchema& schema) {
14+ std::vector<size_t > list_indices;
15+ const auto & args = schema.arguments ();
16+ for (const auto i : c10::irange (args.size ())) {
17+ auto list_type = args[i].type ()->castRaw <ListType>();
18+ if (list_type && list_type->getElementType ()->castRaw <TensorType>()) {
19+ list_indices.push_back (i);
20+ }
21+ }
22+ return list_indices;
23+ }
24+
25+ bool isTensorListConstruct (Node* node) {
26+ if (node->kind () != prim::ListConstruct) {
27+ return false ;
28+ }
29+ const auto type = node->output ()->type ()->castRaw <ListType>();
30+ TORCH_CHECK (type != nullptr );
31+ const auto & elem_type = type->getElementType ();
32+ return elem_type->castRaw <TensorType>();
33+ }
34+
1235class VariadicUpdater {
1336 public:
14- explicit VariadicUpdater (
37+ VariadicUpdater (
1538 std::shared_ptr<Graph> graph,
1639 NodeKind op,
17- NodeKind variadic_op,
18- size_t list_idx = 0 )
40+ NodeKind variadic_op)
1941 : graph_(std::move(graph)),
42+ alias_db_ (graph_),
2043 op_(op),
21- variadic_op_(variadic_op),
22- list_idx_(list_idx) {}
44+ variadic_op_(variadic_op) {}
2345
2446 bool run () {
2547 collectOpNodes (graph_->block ());
@@ -31,90 +53,151 @@ class VariadicUpdater {
3153 }
3254
3355 private:
56+ void recordSchema (Node* op_node) {
57+ const auto & schema = op_node->schema ();
58+ auto it = schema_to_list_indices_.find (schema.name ());
59+ if (it == schema_to_list_indices_.end ()) {
60+ schema_to_list_indices_.emplace (
61+ schema.overload_name (), identifyListArgIndices (schema));
62+ }
63+ }
64+
65+ const std::vector<size_t >& getListIndices (Node* op_node) const {
66+ const auto & schema = op_node->schema ();
67+ auto it = schema_to_list_indices_.find (schema.overload_name ());
68+ TORCH_CHECK (it != schema_to_list_indices_.end ());
69+ return it->second ;
70+ }
71+
3472 void collectOpNodes (Block* block) {
3573 for (auto node : block->nodes ()) {
3674 if (node->kind () == op_) {
3775 op_nodes_.push_back (node);
76+ recordSchema (node);
3877 }
3978 for (Block* b : node->blocks ()) {
4079 collectOpNodes (b);
4180 }
4281 }
4382 }
4483
45- bool replaceWithVariadicOp (Node* op_node) {
84+ bool allListInputsAreValid (Node* op_node) {
4685 const size_t num_inputs = op_node->inputs ().size ();
47- TORCH_CHECK (list_idx_ < num_inputs);
48- if (op_node->input (list_idx_)->node ()->kind () != prim::ListConstruct) {
49- return false ;
86+ for (const auto list_idx : getListIndices (op_node)) {
87+ TORCH_CHECK (list_idx < num_inputs);
88+ const auto list = op_node->input (list_idx)->node ();
89+ // We do not transform ops whose list input can not be moved to the
90+ // position before op. This in turn implies that there is some mutation
91+ // of the input list before op.
92+ if (!isTensorListConstruct (list) ||
93+ !alias_db_.couldMoveBeforeTopologically (list, op_node)) {
94+ return false ;
95+ }
96+ }
97+ return true ;
98+ }
99+
100+ void insertAllInputsBetween (
101+ std::vector<Value*>& inputs,
102+ Node* node,
103+ size_t start_idx,
104+ size_t end_idx) const {
105+ const size_t num_inputs = node->inputs ().size ();
106+ TORCH_CHECK (start_idx <= end_idx && end_idx <= num_inputs);
107+ inputs.insert (
108+ inputs.end (),
109+ node->inputs ().begin () + start_idx,
110+ node->inputs ().begin () + end_idx);
111+ }
112+
113+ void insertIntegerInput (std::vector<Value*>& inputs, size_t input) {
114+ auto constant = graph_->create (prim::Constant);
115+ constant->output ()->setType (c10::IntType::get ());
116+ constant->i_ (attr::value, input);
117+ graph_->prependNode (constant);
118+ inputs.push_back (constant->output ());
119+ }
120+
121+ void deleteOpNodeAndLists (Node* op_node) {
122+ // Collect the lists before we destroy op_node
123+ std::vector<Node*> lists;
124+ const auto & list_indices = getListIndices (op_node);
125+ lists.reserve (list_indices.size ());
126+ for (const size_t list_idx : list_indices) {
127+ auto * list = op_node->input (list_idx)->node ();
128+ lists.push_back (list);
50129 }
51- auto list = op_node->input (list_idx_)->node ();
52- const size_t list_len = list->inputs ().size ();
53130
54- // We do not transform ops whose list input can not be moved to the
55- // position before op. This in turn implies that there is some mutation
56- // of the input list before op.
57- if (!getOrCreateAliasDb ()->couldMoveBeforeTopologically (list, op_node)) {
131+ GRAPH_UPDATE (" Deleting\n " , *op_node);
132+ op_node->destroy ();
133+ for (auto * list : lists) {
134+ if (!list->hasUses ()) {
135+ GRAPH_UPDATE (" Deleting\n " , *list);
136+ list->destroy ();
137+ }
138+ }
139+ }
140+
141+ bool replaceWithVariadicOp (Node* op_node) {
142+ if (!allListInputsAreValid (op_node)) {
58143 return false ;
59144 }
60145
61- // Construct new inputs
62146 std::vector<Value*> inputs;
63- inputs.reserve (num_inputs + list_len - 1 );
64- inputs.insert (
65- inputs.end (),
66- op_node->inputs ().begin (),
67- op_node->inputs ().begin () + list_idx_);
68- inputs.insert (inputs.end (), list->inputs ().begin (), list->inputs ().end ());
69- inputs.insert (
70- inputs.end (),
71- op_node->inputs ().begin () + list_idx_ + 1 ,
72- op_node->inputs ().end ());
147+ size_t cur_idx = 0 ;
148+ std::vector<size_t > list_lens;
149+ for (const size_t list_idx : getListIndices (op_node)) {
150+ insertAllInputsBetween (inputs, op_node, cur_idx, list_idx);
151+ const auto list = op_node->input (list_idx)->node ();
152+ const auto list_len = list->inputs ().size ();
153+ list_lens.push_back (list_len);
154+ insertAllInputsBetween (inputs, list, 0 , list_len);
155+ cur_idx = list_idx + 1 ;
156+ }
157+ insertAllInputsBetween (inputs, op_node, cur_idx, op_node->inputs ().size ());
158+
159+ // We insert these extra integers at the end of the argument list only if we
160+ // have more than one variadic list (the information is redundant when there
161+ // is only one list because the interpreter knows how many arguments there
162+ // are).
163+ if (list_lens.size () > 1 ) {
164+ for (const size_t list_len : list_lens) {
165+ insertIntegerInput (inputs, list_len);
166+ }
167+ }
73168
74169 auto var_op_node = op_node->owningGraph ()->create (variadic_op_, inputs);
75170 var_op_node->output ()->setType (op_node->output ()->type ());
76171 GRAPH_UPDATE (" Adding\n " , *var_op_node);
77172 var_op_node->insertBefore (op_node);
78173 GRAPH_UPDATE (" Replacing\n " , *op_node, " with\n " , *var_op_node);
79174 op_node->output ()->replaceAllUsesWith (var_op_node->output ());
80- GRAPH_UPDATE (" Deleting\n " , *op_node);
81- op_node->destroy ();
82- if (!list->hasUses ()) {
83- GRAPH_UPDATE (" Deleting\n " , *list);
84- list->destroy ();
85- }
175+ deleteOpNodeAndLists (op_node);
86176 return true ;
87177 }
88178
89- AliasDb* getOrCreateAliasDb () {
90- if (!aliasDb_) {
91- aliasDb_ = std::make_unique<AliasDb>(graph_);
92- }
93- return aliasDb_.get ();
94- }
95-
96179 std::shared_ptr<Graph> graph_;
97- std::unique_ptr<AliasDb> aliasDb_ = nullptr ;
98-
99180 std::vector<Node*> op_nodes_;
100181
182+ AliasDb alias_db_;
183+
101184 NodeKind op_;
102185 NodeKind variadic_op_;
103186
104- size_t list_idx_ ;
187+ std::unordered_map<std::string, std::vector< size_t >> schema_to_list_indices_ ;
105188};
106189
107190} // namespace
108191
109192bool UseVariadicOp (
110193 const std::shared_ptr<Graph>& graph,
111194 NodeKind op,
112- NodeKind variadic_op,
113- size_t list_idx) {
195+ NodeKind variadic_op) {
114196 const std::string pass_name = std::string (" variadic " ) + op.toQualString ();
115197 GRAPH_DUMP (" Before " + pass_name, graph);
116- bool changed = VariadicUpdater (graph, op, variadic_op, list_idx ).run ();
198+ bool changed = VariadicUpdater (graph, op, variadic_op).run ();
117199 if (changed) {
200+ ConstantPooling (graph);
118201 GRAPH_DUMP (" After " + pass_name, graph);
119202 }
120203 return changed;
@@ -123,14 +206,13 @@ bool UseVariadicOp(
123206bool RemoveListMutationAndUseVariadicOp (
124207 const std::shared_ptr<Graph>& graph,
125208 NodeKind op,
126- NodeKind variadic_op,
127- size_t list_idx) {
209+ NodeKind variadic_op) {
128210 bool changed_in_last_iter = true ;
129211 bool changed = false ;
130212 while (changed_in_last_iter) {
131213 changed_in_last_iter = RemoveListMutation (graph);
132214 changed_in_last_iter =
133- UseVariadicOp (graph, op, variadic_op, list_idx ) || changed_in_last_iter;
215+ UseVariadicOp (graph, op, variadic_op) || changed_in_last_iter;
134216 changed = changed || changed_in_last_iter;
135217 }
136218 return changed;
@@ -139,15 +221,13 @@ bool RemoveListMutationAndUseVariadicOp(
139221bool UseVariadicCat (const std::shared_ptr<Graph>& graph) {
140222 return UseVariadicOp (graph, aten::cat, prim::VarConcat);
141223}
142-
143224bool RemoveListMutationAndUseVariadicCat (const std::shared_ptr<Graph>& graph) {
144225 return RemoveListMutationAndUseVariadicOp (graph, aten::cat, prim::VarConcat);
145226}
146227
147228bool UseVariadicStack (const std::shared_ptr<Graph>& graph) {
148229 return UseVariadicOp (graph, aten::stack, prim::VarStack);
149230}
150-
151231bool RemoveListMutationAndUseVariadicStack (
152232 const std::shared_ptr<Graph>& graph) {
153233 return RemoveListMutationAndUseVariadicOp (graph, aten::stack, prim::VarStack);
0 commit comments