
Commit f56dd22

yunjiangster authored and facebook-github-bot committed
Add max_length to JaggedTensor api (#3160)
Summary:
Pull Request resolved: #3160

We often need to recompute the max_length of a JaggedTensor on the fly, which fragments the compute graph. This change follows the caching approach of the existing lengths/offsets implementation for max_length.

Reviewed By: shruthign

Differential Revision: D77709565

fbshipit-source-id: cad35e4c16607c859291d18d1b4f90f98d8a0002
1 parent fc37e53 commit f56dd22
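
For context, here is a minimal usage sketch of the API this commit adds. The import path is assumed from the file touched below (torchrec/sparse/jagged_tensor.py), and the values/offsets mirror the new unit test; treat this as an illustration rather than part of the change itself.

import torch
from torchrec.sparse.jagged_tensor import JaggedTensor

# Lengths derived from these offsets are [2, 0, 1, 1, 1, 3], so the max length is 3.
values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8])
jt = JaggedTensor(values=values, offsets=offsets)

print(jt.max_length_or_none())  # None: nothing cached yet
print(jt.max_length())          # 3: computed from lengths, then cached
print(jt.max_length_or_none())  # 3: later calls reuse the cached value

# A precomputed max_length can also be passed at construction to skip the recompute.
jt_precomputed = JaggedTensor(values=values, offsets=offsets, max_length=3)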

File tree

2 files changed: +43 −1 lines changed


torchrec/sparse/jagged_tensor.py

Lines changed: 35 additions & 1 deletion
@@ -98,6 +98,15 @@ def _maybe_compute_lengths(
     return lengths
 
 
+@torch.fx.wrap
+def _maybe_compute_max_length(lengths: torch.Tensor, max_length: Optional[int]) -> int:
+    if max_length is None:
+        if lengths.numel() == 0:
+            return 0
+        max_length = int(lengths.max().item())
+    return max_length
+
+
 def _maybe_compute_offsets(
     lengths: Optional[torch.Tensor], offsets: Optional[torch.Tensor]
 ) -> torch.Tensor:

@@ -581,14 +590,15 @@ class JaggedTensor(Pipelineable, metaclass=JaggedTensorMeta):
             offsets.
     """
 
-    _fields = ["_values", "_weights", "_lengths", "_offsets"]
+    _fields = ["_values", "_weights", "_lengths", "_offsets", "_max_length"]
 
     def __init__(
         self,
         values: torch.Tensor,
         weights: Optional[torch.Tensor] = None,
         lengths: Optional[torch.Tensor] = None,
         offsets: Optional[torch.Tensor] = None,
+        max_length: Optional[int] = None,
     ) -> None:
 
         self._values: torch.Tensor = values

@@ -600,6 +610,7 @@ def __init__(
         _assert_tensor_has_no_elements_or_has_integers(lengths, "lengths")
         self._lengths: Optional[torch.Tensor] = lengths
         self._offsets: Optional[torch.Tensor] = offsets
+        self._max_length: Optional[int] = max_length
 
     @staticmethod
     def empty(

@@ -630,6 +641,7 @@ def empty(
             offsets=torch.empty(0, dtype=lengths_dtype, device=device),
             lengths=torch.empty(0, dtype=lengths_dtype, device=device),
             weights=weights,
+            max_length=0,
         )
 
     @staticmethod

@@ -912,6 +924,26 @@ def lengths_or_none(self) -> Optional[torch.Tensor]:
         """
         return self._lengths
 
+    def max_length(self) -> int:
+        """
+        Get the maximum length of the JaggedTensor.
+
+        Returns:
+            int: the maximum length of the JaggedTensor.
+        """
+        _max_length = _maybe_compute_max_length(self.lengths(), self._max_length)
+        self._max_length = _max_length
+        return _max_length
+
+    def max_length_or_none(self) -> Optional[int]:
+        """
+        Get the maximum length of the JaggedTensor. If not computed, return None.
+
+        Returns:
+            Optional[int]: the maximum length of the JaggedTensor.
+        """
+        return self._max_length
+
     def offsets(self) -> torch.Tensor:
         """
         Get JaggedTensor offsets. If not computed, compute it from lengths.

@@ -973,6 +1005,7 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "JaggedTensor"
         weights = self._weights
         lengths = self._lengths
         offsets = self._offsets
+        max_length = self._max_length
         return JaggedTensor(
             values=self._values.to(device, non_blocking=non_blocking),
             weights=(

@@ -990,6 +1023,7 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "JaggedTensor"
                 if offsets is not None
                 else None
             ),
+            max_length=max_length,
         )
 
     @torch.jit.unused
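
The @torch.fx.wrap decorator mirrors the existing _maybe_compute_lengths/_maybe_compute_offsets helpers: during torch.fx symbolic tracing the wrapped helper is preserved as a single call node rather than traced through, which keeps the data-dependent lengths.max().item() call from fragmenting the graph. Below is a standalone sketch of the same compute-once-then-cache pattern; the names (CachedLengths, _maybe_compute_max_len) are illustrative and not part of torchrec.

from typing import Optional

import torch


@torch.fx.wrap
def _maybe_compute_max_len(lengths: torch.Tensor, max_len: Optional[int]) -> int:
    # Only touch the tensor when no cached value is available.
    if max_len is None:
        if lengths.numel() == 0:
            return 0
        max_len = int(lengths.max().item())
    return max_len


class CachedLengths:
    """Toy holder showing the caching pattern JaggedTensor.max_length() follows."""

    def __init__(self, lengths: torch.Tensor, max_len: Optional[int] = None) -> None:
        self._lengths = lengths
        self._max_len: Optional[int] = max_len

    def max_len(self) -> int:
        # First call computes and caches; subsequent calls return the cached int.
        self._max_len = _maybe_compute_max_len(self._lengths, self._max_len)
        return self._max_len


holder = CachedLengths(torch.tensor([2, 0, 1, 1, 1, 3]))
assert holder.max_len() == 3  # computed on first access
assert holder.max_len() == 3  # served from the cache afterwards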

torchrec/sparse/tests/test_jagged_tensor.py

Lines changed: 8 additions & 0 deletions
@@ -568,6 +568,14 @@ def test_length_vs_offset(self) -> None:
         self.assertTrue(torch.equal(j_offset.lengths(), j_lens.lengths()))
         self.assertTrue(torch.equal(j_offset.offsets(), j_lens.offsets().int()))
 
+    def test_max_length(self) -> None:
+        values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
+        offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8])
+        jt = JaggedTensor(values=values, offsets=offsets)
+        self.assertIsNone(jt.max_length_or_none())
+        self.assertEqual(jt.max_length(), 3)
+        self.assertEqual(jt.max_length_or_none(), 3)
+
     def test_empty(self) -> None:
         jt = JaggedTensor.empty(values_dtype=torch.int64)