@@ -255,6 +255,27 @@ namespace {
255255 fuse.runOnGraph (graph);
256256}
257257
258+ // Similar to ClipRangesToGatherToOffsets, but for the case where type of aten::to is from
259+ // gather_ranges's data output instead of the graph input.
260+ [[maybe_unused]] void ClipRangesToGatherToOffsetsV2 (
261+ std::shared_ptr<torch::jit::Graph>& graph) {
262+ std::string pattern = R"IR(
263+ graph(%a, %b, %c, %d, %to0_in0):
264+ %y0 : Tensor, %y1 : Tensor = fb::clip_ranges_gather(%a, %b, %c)
265+ %y0_type : int = prim::dtype(%y0)
266+ %y2 : Tensor = aten::to(%y1, %y0_type, %to0_in0, %to0_in0)
267+ %y3 : Tensor = fb::lengths_to_offsets(%y2, %d)
268+ return (%y3, %y0))IR" ;
269+ std::string fused_pattern = R"IR(
270+ graph(%a, %b, %c, %d, %to0_in0):
271+ %a_type : int = prim::dtype(%a)
272+ %y0 : Tensor, %y1 : Tensor = fb::clip_ranges_gather_to_offsets(%a, %b, %c, %d, %a_type)
273+ return (%y1, %y0))IR" ;
274+ SubgraphRewriter fuse;
275+ fuse.RegisterRewritePattern (pattern, fused_pattern);
276+ fuse.runOnGraph (graph);
277+ }
278+
258279[[maybe_unused]] void ToLengthsToOffsets (
259280 std::shared_ptr<torch::jit::Graph>& graph) {
260281 std::string pattern = R"IR(
@@ -389,7 +410,8 @@ void FuseInferenceOpsForSparseNN(std::shared_ptr<torch::jit::Graph>& graph) {
389410 // prioritize clip_ranges+gather_ranges+sigrid_hash fusion over
390411 // clip_ranges+gather_ranges
391412 ClipRangesGather (graph);
392-
413+ // Must run before ClipRangesToGatherToOffsets.
414+ ClipRangesToGatherToOffsetsV2 (graph);
393415 ClipRangesToGatherToOffsets (graph);
394416 }
395417
0 commit comments