|
36 | 36 | graph, |
37 | 37 | ir, |
38 | 38 | ) |
| 39 | +from onnxscript._internal import version_utils |
39 | 40 | from onnxscript.function_libs.torch_lib.ops import common as common_ops |
40 | 41 | from onnxscript.function_libs.torch_lib.registration import torch_op |
41 | 42 | from onnxscript.function_libs.torch_lib.tensor_typing import ( |
@@ -1647,29 +1648,40 @@ def aten_choose_qparams_optimized( |
1647 | 1648 | raise NotImplementedError() |
1648 | 1649 |
|
1649 | 1650 |
|
1650 | | -@torch_op("aten::chunk") |
1651 | | -def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]: |
1652 | | - """chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]""" |
1653 | | - # This will create a Sequence of tensors |
1654 | | - neg_1 = op.Constant(value_ints=[-1]) |
1655 | | - # Get size of specified dim |
1656 | | - self_shape = op.Shape(self) |
1657 | | - dim_size = op.Gather(self_shape, dim, axis=0) |
1658 | | - # Compute size/chunk to get the number of data in one chunk |
1659 | | - num_per_chunk = op.Div(dim_size, chunks) |
1660 | | - num_per_chunk = op.Cast(op.Mod(dim_size, chunks) > 0, to=INT64.dtype) + num_per_chunk # type: ignore[operator] |
1661 | | - |
1662 | | - # Compute real chunk number |
1663 | | - num_chunk = op.Div(dim_size, num_per_chunk) |
1664 | | - # Get something like [n, n, n, n, ...], total num_chunk |
1665 | | - list_split = op.Expand(num_per_chunk, op.Reshape(num_chunk, neg_1)) |
1666 | | - |
1667 | | - remainder = op.Mod(dim_size, num_per_chunk) |
1668 | | - if remainder > 0: # type: ignore[operator] |
1669 | | - # Append the remainder to the [n, n, n, n, ..., r] |
1670 | | - list_split = op.Concat(list_split, op.Reshape(remainder, neg_1), axis=0) |
1671 | | - |
1672 | | - return op.SplitToSequence(self, list_split, axis=dim) |
| 1651 | +if version_utils.torch_older_than("2.7.0"): |
| 1652 | + # PyTorch <2.7 does not support determining the number of outputs for the Split op |
| 1653 | + # https://github.com/pytorch/pytorch/commit/9a1eac6704671c72a2e85c9138db57eb3a80bfb6 |
| 1654 | + @torch_op("aten::chunk") |
| 1655 | + def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]: |
| 1656 | + """chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]""" |
| 1657 | + # This will create a Sequence of tensors |
| 1658 | + neg_1 = op.Constant(value_ints=[-1]) |
| 1659 | + # Get size of specified dim |
| 1660 | + self_shape = op.Shape(self) |
| 1661 | + dim_size = op.Gather(self_shape, dim, axis=0) |
| 1662 | + # Compute size/chunk to get the number of data in one chunk |
| 1663 | + num_per_chunk = op.Div(dim_size, chunks) |
| 1664 | + num_per_chunk = op.Cast(op.Mod(dim_size, chunks) > 0, to=INT64.dtype) + num_per_chunk # type: ignore[operator] |
| 1665 | + |
| 1666 | + # Compute real chunk number |
| 1667 | + num_chunk = op.Div(dim_size, num_per_chunk) |
| 1668 | + # Get something like [n, n, n, n, ...], total num_chunk |
| 1669 | + list_split = op.Expand(num_per_chunk, op.Reshape(num_chunk, neg_1)) |
| 1670 | + |
| 1671 | + remainder = op.Mod(dim_size, num_per_chunk) |
| 1672 | + if remainder > 0: # type: ignore[operator] |
| 1673 | + # Append the remainder to the [n, n, n, n, ..., r] |
| 1674 | + list_split = op.Concat(list_split, op.Reshape(remainder, neg_1), axis=0) |
| 1675 | + |
| 1676 | + return op.SplitToSequence(self, list_split, axis=dim) |
| 1677 | +else: |
| 1678 | + |
| 1679 | + @torch_op("aten::chunk", trace_only=True) |
| 1680 | + def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]: |
| 1681 | + """chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]""" |
| 1682 | + if chunks == 1: |
| 1683 | + return op.Identity(self) |
| 1684 | + return op.Split(self, axis=dim, num_outputs=chunks) |
1673 | 1685 |
|
1674 | 1686 |
|
1675 | 1687 | @torch_op("aten::clamp", trace_only=True) |
|
0 commit comments