Skip to content

Commit d774a63

Browse files
coufonpytorchmergebot
authored andcommitted
[StaticRuntime] Support a new pattern for ClipRangesToGatherToOffsets (pytorch#146931)
Summary: Support the following new pattern for ClipRangesToGatherToOffsets: Before optimization: ``` %18267 : Tensor, %18268 : Tensor = fb::clip_ranges_gather(%int_77.1, %getitem_2484.1, %493) %getattr_368.1 : int = prim::dtype(%18267) %to_443.1 : Tensor = aten::to(%18268, %getattr_368.1, %self._maybe_compute_kjt_to_jt_dict.is_weighted, %self._maybe_compute_kjt_to_jt_dict.is_weighted) %lengths_to_offsets_490.1 : Tensor = fb::lengths_to_offsets(%to_443.1, %8) ``` After optimization: ``` %18297 : int = prim::dtype(%int_77.1) %18298 : Tensor, %18299 : Tensor = fb::clip_ranges_gather_to_offsets(%int_77.1, %getitem_2484.1, %493, %8, %18297) ``` Reviewed By: garroud Differential Revision: D69373835 Pull Request resolved: pytorch#146931 Approved by: https://github.com/hanyilou123
1 parent ae5cc19 commit d774a63

File tree

1 file changed

+23
-1
lines changed

1 file changed

+23
-1
lines changed

torch/csrc/jit/runtime/static/passes.cpp

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,27 @@ namespace {
255255
fuse.runOnGraph(graph);
256256
}
257257

258+
// Similar to ClipRangesToGatherToOffsets, but for the case where type of aten::to is from
259+
// gather_ranges's data output instead of the graph input.
260+
[[maybe_unused]] void ClipRangesToGatherToOffsetsV2(
261+
std::shared_ptr<torch::jit::Graph>& graph) {
262+
std::string pattern = R"IR(
263+
graph(%a, %b, %c, %d, %to0_in0):
264+
%y0 : Tensor, %y1 : Tensor = fb::clip_ranges_gather(%a, %b, %c)
265+
%y0_type : int = prim::dtype(%y0)
266+
%y2 : Tensor = aten::to(%y1, %y0_type, %to0_in0, %to0_in0)
267+
%y3 : Tensor = fb::lengths_to_offsets(%y2, %d)
268+
return (%y3, %y0))IR";
269+
std::string fused_pattern = R"IR(
270+
graph(%a, %b, %c, %d, %to0_in0):
271+
%a_type : int = prim::dtype(%a)
272+
%y0 : Tensor, %y1 : Tensor = fb::clip_ranges_gather_to_offsets(%a, %b, %c, %d, %a_type)
273+
return (%y1, %y0))IR";
274+
SubgraphRewriter fuse;
275+
fuse.RegisterRewritePattern(pattern, fused_pattern);
276+
fuse.runOnGraph(graph);
277+
}
278+
258279
[[maybe_unused]] void ToLengthsToOffsets(
259280
std::shared_ptr<torch::jit::Graph>& graph) {
260281
std::string pattern = R"IR(
@@ -389,7 +410,8 @@ void FuseInferenceOpsForSparseNN(std::shared_ptr<torch::jit::Graph>& graph) {
389410
// prioritize clip_ranges+gather_ranges+sigrid_hash fusion over
390411
// clip_ranges+gather_ranges
391412
ClipRangesGather(graph);
392-
413+
// Must run before ClipRangesToGatherToOffsets.
414+
ClipRangesToGatherToOffsetsV2(graph);
393415
ClipRangesToGatherToOffsets(graph);
394416
}
395417

0 commit comments

Comments
 (0)