@@ -98,15 +98,6 @@ def _maybe_compute_lengths(
98
98
return lengths
99
99
100
100
101
- @torch .fx .wrap
102
- def _maybe_compute_max_length (lengths : torch .Tensor , max_length : Optional [int ]) -> int :
103
- if max_length is None :
104
- if lengths .numel () == 0 :
105
- return 0
106
- max_length = int (lengths .max ().item ())
107
- return max_length
108
-
109
-
110
101
def _maybe_compute_offsets (
111
102
lengths : Optional [torch .Tensor ], offsets : Optional [torch .Tensor ]
112
103
) -> torch .Tensor :
@@ -590,15 +581,14 @@ class JaggedTensor(Pipelineable, metaclass=JaggedTensorMeta):
590
581
offsets.
591
582
"""
592
583
593
- _fields = ["_values" , "_weights" , "_lengths" , "_offsets" , "_max_length" ]
584
+ _fields = ["_values" , "_weights" , "_lengths" , "_offsets" ]
594
585
595
586
def __init__ (
596
587
self ,
597
588
values : torch .Tensor ,
598
589
weights : Optional [torch .Tensor ] = None ,
599
590
lengths : Optional [torch .Tensor ] = None ,
600
591
offsets : Optional [torch .Tensor ] = None ,
601
- max_length : Optional [int ] = None ,
602
592
) -> None :
603
593
604
594
self ._values : torch .Tensor = values
@@ -610,7 +600,6 @@ def __init__(
610
600
_assert_tensor_has_no_elements_or_has_integers (lengths , "lengths" )
611
601
self ._lengths : Optional [torch .Tensor ] = lengths
612
602
self ._offsets : Optional [torch .Tensor ] = offsets
613
- self ._max_length : Optional [int ] = max_length
614
603
615
604
@staticmethod
616
605
def empty (
@@ -641,7 +630,6 @@ def empty(
641
630
offsets = torch .empty (0 , dtype = lengths_dtype , device = device ),
642
631
lengths = torch .empty (0 , dtype = lengths_dtype , device = device ),
643
632
weights = weights ,
644
- max_length = 0 ,
645
633
)
646
634
647
635
@staticmethod
@@ -924,26 +912,6 @@ def lengths_or_none(self) -> Optional[torch.Tensor]:
924
912
"""
925
913
return self ._lengths
926
914
927
- def max_length (self ) -> int :
928
- """
929
- Get the maximum length of the JaggedTensor.
930
-
931
- Returns:
932
- int: the maximum length of the JaggedTensor.
933
- """
934
- _max_length = _maybe_compute_max_length (self .lengths (), self ._max_length )
935
- self ._max_length = _max_length
936
- return _max_length
937
-
938
- def max_length_or_none (self ) -> Optional [int ]:
939
- """
940
- Get the maximum length of the JaggedTensor. If not computed, return None.
941
-
942
- Returns:
943
- Optional[int]: the maximum length of the JaggedTensor.
944
- """
945
- return self ._max_length
946
-
947
915
def offsets (self ) -> torch .Tensor :
948
916
"""
949
917
Get JaggedTensor offsets. If not computed, compute it from lengths.
@@ -1005,7 +973,6 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "JaggedTensor"
1005
973
weights = self ._weights
1006
974
lengths = self ._lengths
1007
975
offsets = self ._offsets
1008
- max_length = self ._max_length
1009
976
return JaggedTensor (
1010
977
values = self ._values .to (device , non_blocking = non_blocking ),
1011
978
weights = (
@@ -1023,7 +990,6 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "JaggedTensor"
1023
990
if offsets is not None
1024
991
else None
1025
992
),
1026
- max_length = max_length ,
1027
993
)
1028
994
1029
995
@torch .jit .unused
0 commit comments