@@ -680,7 +680,8 @@ def modalities(self):
680
680
return self ._items_by_modality .keys ()
681
681
682
682
@staticmethod
683
- def _try_stack (nested_tensors : NestedTensors ) -> NestedTensors :
683
+ def _try_stack (nested_tensors : NestedTensors ,
684
+ pin_memory : bool = False ) -> NestedTensors :
684
685
"""
685
686
Stack the inner dimensions that have the same shape in
686
687
a nested list of tensors.
@@ -697,7 +698,9 @@ def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
697
698
if isinstance (nested_tensors , (int , float )):
698
699
return torch .tensor (nested_tensors )
699
700
700
- stacked = [MultiModalKwargs ._try_stack (t ) for t in nested_tensors ]
701
+ stacked = [
702
+ MultiModalKwargs ._try_stack (t , pin_memory ) for t in nested_tensors
703
+ ]
701
704
if not is_list_of (stacked , torch .Tensor , check = "all" ):
702
705
# Only tensors (not lists) can be stacked.
703
706
return stacked
@@ -713,10 +716,16 @@ def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
713
716
# The tensors have incompatible shapes and can't be stacked.
714
717
return tensors_
715
718
716
- return torch .stack (tensors_ )
719
+ outputs = torch .empty (len (tensors_ ),
720
+ * tensors_ [0 ].shape ,
721
+ dtype = tensors_ [0 ].dtype ,
722
+ device = tensors_ [0 ].device ,
723
+ pin_memory = pin_memory )
724
+ return torch .stack (tensors_ , out = outputs )
717
725
718
726
@staticmethod
719
- def batch (inputs_list : list ["MultiModalKwargs" ]) -> BatchedTensorInputs :
727
+ def batch (inputs_list : list ["MultiModalKwargs" ],
728
+ pin_memory : bool = False ) -> BatchedTensorInputs :
720
729
"""
721
730
Batch multiple inputs together into a dictionary.
722
731
@@ -738,7 +747,7 @@ def batch(inputs_list: list["MultiModalKwargs"]) -> BatchedTensorInputs:
738
747
item_lists [k ].append (v )
739
748
740
749
return {
741
- k : MultiModalKwargs ._try_stack (item_list )
750
+ k : MultiModalKwargs ._try_stack (item_list , pin_memory )
742
751
for k , item_list in item_lists .items ()
743
752
}
744
753
0 commit comments