@@ -98,6 +98,15 @@ def _maybe_compute_lengths(
98
98
return lengths
99
99
100
100
101
+ @torch .fx .wrap
102
+ def _maybe_compute_max_length (lengths : torch .Tensor , max_length : Optional [int ]) -> int :
103
+ if max_length is None :
104
+ if lengths .numel () == 0 :
105
+ return 0
106
+ max_length = int (lengths .max ().item ())
107
+ return max_length
108
+
109
+
101
110
def _maybe_compute_offsets (
102
111
lengths : Optional [torch .Tensor ], offsets : Optional [torch .Tensor ]
103
112
) -> torch .Tensor :
@@ -581,14 +590,15 @@ class JaggedTensor(Pipelineable, metaclass=JaggedTensorMeta):
581
590
offsets.
582
591
"""
583
592
584
- _fields = ["_values" , "_weights" , "_lengths" , "_offsets" ]
593
+ _fields = ["_values" , "_weights" , "_lengths" , "_offsets" , "_max_length" ]
585
594
586
595
def __init__ (
587
596
self ,
588
597
values : torch .Tensor ,
589
598
weights : Optional [torch .Tensor ] = None ,
590
599
lengths : Optional [torch .Tensor ] = None ,
591
600
offsets : Optional [torch .Tensor ] = None ,
601
+ max_length : Optional [int ] = None ,
592
602
) -> None :
593
603
594
604
self ._values : torch .Tensor = values
@@ -600,6 +610,7 @@ def __init__(
600
610
_assert_tensor_has_no_elements_or_has_integers (lengths , "lengths" )
601
611
self ._lengths : Optional [torch .Tensor ] = lengths
602
612
self ._offsets : Optional [torch .Tensor ] = offsets
613
+ self ._max_length : Optional [int ] = max_length
603
614
604
615
@staticmethod
605
616
def empty (
@@ -630,6 +641,7 @@ def empty(
630
641
offsets = torch .empty (0 , dtype = lengths_dtype , device = device ),
631
642
lengths = torch .empty (0 , dtype = lengths_dtype , device = device ),
632
643
weights = weights ,
644
+ max_length = 0 ,
633
645
)
634
646
635
647
@staticmethod
@@ -912,6 +924,26 @@ def lengths_or_none(self) -> Optional[torch.Tensor]:
912
924
"""
913
925
return self ._lengths
914
926
927
+ def max_length (self ) -> int :
928
+ """
929
+ Get the maximum length of the JaggedTensor.
930
+
931
+ Returns:
932
+ int: the maximum length of the JaggedTensor.
933
+ """
934
+ _max_length = _maybe_compute_max_length (self .lengths (), self ._max_length )
935
+ self ._max_length = _max_length
936
+ return _max_length
937
+
938
+ def max_length_or_none (self ) -> Optional [int ]:
939
+ """
940
+ Get the maximum length of the JaggedTensor. If not computed, return None.
941
+
942
+ Returns:
943
+ Optional[int]: the maximum length of the JaggedTensor.
944
+ """
945
+ return self ._max_length
946
+
915
947
def offsets (self ) -> torch .Tensor :
916
948
"""
917
949
Get JaggedTensor offsets. If not computed, compute it from lengths.
@@ -973,6 +1005,7 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "JaggedTensor"
973
1005
weights = self ._weights
974
1006
lengths = self ._lengths
975
1007
offsets = self ._offsets
1008
+ max_length = self ._max_length
976
1009
return JaggedTensor (
977
1010
values = self ._values .to (device , non_blocking = non_blocking ),
978
1011
weights = (
@@ -990,6 +1023,7 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "JaggedTensor"
990
1023
if offsets is not None
991
1024
else None
992
1025
),
1026
+ max_length = max_length ,
993
1027
)
994
1028
995
1029
@torch .jit .unused
0 commit comments