Skip to content

Commit 9e96a37

Browse files
jd7-trfacebook-github-bot
authored andcommitted
Back out "Add max_length to JaggedTensor api" (#3174)
Summary: Pull Request resolved: #3174 Original commit changeset: cad35e4c1660 Original Phabricator Diff: D77709565 Reviewed By: aliafzal, spmex, iamzainhuda Differential Revision: D77985223 fbshipit-source-id: 2339220ffbc99c15b4811b6356598d84ac4d6239
1 parent a0de1fd commit 9e96a37

File tree

2 files changed

+1
-43
lines changed

2 files changed

+1
-43
lines changed

torchrec/sparse/jagged_tensor.py

Lines changed: 1 addition & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -98,15 +98,6 @@ def _maybe_compute_lengths(
9898
return lengths
9999

100100

101-
@torch.fx.wrap
102-
def _maybe_compute_max_length(lengths: torch.Tensor, max_length: Optional[int]) -> int:
103-
if max_length is None:
104-
if lengths.numel() == 0:
105-
return 0
106-
max_length = int(lengths.max().item())
107-
return max_length
108-
109-
110101
def _maybe_compute_offsets(
111102
lengths: Optional[torch.Tensor], offsets: Optional[torch.Tensor]
112103
) -> torch.Tensor:
@@ -590,15 +581,14 @@ class JaggedTensor(Pipelineable, metaclass=JaggedTensorMeta):
590581
offsets.
591582
"""
592583

593-
_fields = ["_values", "_weights", "_lengths", "_offsets", "_max_length"]
584+
_fields = ["_values", "_weights", "_lengths", "_offsets"]
594585

595586
def __init__(
596587
self,
597588
values: torch.Tensor,
598589
weights: Optional[torch.Tensor] = None,
599590
lengths: Optional[torch.Tensor] = None,
600591
offsets: Optional[torch.Tensor] = None,
601-
max_length: Optional[int] = None,
602592
) -> None:
603593

604594
self._values: torch.Tensor = values
@@ -610,7 +600,6 @@ def __init__(
610600
_assert_tensor_has_no_elements_or_has_integers(lengths, "lengths")
611601
self._lengths: Optional[torch.Tensor] = lengths
612602
self._offsets: Optional[torch.Tensor] = offsets
613-
self._max_length: Optional[int] = max_length
614603

615604
@staticmethod
616605
def empty(
@@ -641,7 +630,6 @@ def empty(
641630
offsets=torch.empty(0, dtype=lengths_dtype, device=device),
642631
lengths=torch.empty(0, dtype=lengths_dtype, device=device),
643632
weights=weights,
644-
max_length=0,
645633
)
646634

647635
@staticmethod
@@ -924,26 +912,6 @@ def lengths_or_none(self) -> Optional[torch.Tensor]:
924912
"""
925913
return self._lengths
926914

927-
def max_length(self) -> int:
928-
"""
929-
Get the maximum length of the JaggedTensor.
930-
931-
Returns:
932-
int: the maximum length of the JaggedTensor.
933-
"""
934-
_max_length = _maybe_compute_max_length(self.lengths(), self._max_length)
935-
self._max_length = _max_length
936-
return _max_length
937-
938-
def max_length_or_none(self) -> Optional[int]:
939-
"""
940-
Get the maximum length of the JaggedTensor. If not computed, return None.
941-
942-
Returns:
943-
Optional[int]: the maximum length of the JaggedTensor.
944-
"""
945-
return self._max_length
946-
947915
def offsets(self) -> torch.Tensor:
948916
"""
949917
Get JaggedTensor offsets. If not computed, compute it from lengths.
@@ -1005,7 +973,6 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "JaggedTensor"
1005973
weights = self._weights
1006974
lengths = self._lengths
1007975
offsets = self._offsets
1008-
max_length = self._max_length
1009976
return JaggedTensor(
1010977
values=self._values.to(device, non_blocking=non_blocking),
1011978
weights=(
@@ -1023,7 +990,6 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "JaggedTensor"
1023990
if offsets is not None
1024991
else None
1025992
),
1026-
max_length=max_length,
1027993
)
1028994

1029995
@torch.jit.unused

torchrec/sparse/tests/test_jagged_tensor.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -568,14 +568,6 @@ def test_length_vs_offset(self) -> None:
568568
self.assertTrue(torch.equal(j_offset.lengths(), j_lens.lengths()))
569569
self.assertTrue(torch.equal(j_offset.offsets(), j_lens.offsets().int()))
570570

571-
def test_max_length(self) -> None:
572-
values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
573-
offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8])
574-
jt = JaggedTensor(values=values, offsets=offsets)
575-
self.assertIsNone(jt.max_length_or_none())
576-
self.assertEqual(jt.max_length(), 3)
577-
self.assertEqual(jt.max_length_or_none(), 3)
578-
579571
def test_empty(self) -> None:
580572
jt = JaggedTensor.empty(values_dtype=torch.int64)
581573

0 commit comments

Comments
 (0)