This repository was archived by the owner on Jul 4, 2023. It is now read-only.

Commit aa50d77

Merge pull request #58 from PetrochukM/pad_batch
Allow pad_batch concatenate along any dimension
2 parents 133a54c + aa98278 commit aa50d77

1 file changed: torchnlp/utils.py

Lines changed: 5 additions & 2 deletions
@@ -93,18 +93,21 @@ def pad_tensor(tensor, length, padding_index=PADDING_INDEX):
     return torch.cat((tensor, padding), dim=0)
 
 
-def pad_batch(batch, padding_index=PADDING_INDEX):
+def pad_batch(batch, padding_index=PADDING_INDEX, dim=0):
     """ Pad a :class:`list` of ``tensors`` (``batch``) with ``padding_index``.
+
     Args:
         batch (:class:`list` of :class:`torch.Tensor`): Batch of tensors to pad.
         padding_index (int, optional): Index to pad tensors with.
+        dim (int, optional): Dimension on to which to concatenate the batch of tensors.
+
     Returns
         torch.Tensor, list of int: Padded tensors and original lengths of tensors.
     """
     lengths = [tensor.shape[0] for tensor in batch]
     max_len = max(lengths)
     padded = [pad_tensor(tensor, max_len, padding_index) for tensor in batch]
-    padded = torch.stack(padded, dim=0).contiguous()
+    padded = torch.stack(padded, dim=dim).contiguous()
     return padded, lengths
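A minimal sketch of how the patched `pad_batch` behaves with the new `dim` argument. This is not the full `torchnlp/utils.py`; `PADDING_INDEX = 0` is an assumption for this sketch (the real constant is defined elsewhere in the module), and `pad_tensor` is reimplemented here only as far as the diff context shows.

```python
import torch

PADDING_INDEX = 0  # assumption for this sketch; defined elsewhere in torchnlp

def pad_tensor(tensor, length, padding_index=PADDING_INDEX):
    """Pad ``tensor`` along dim 0 up to ``length`` with ``padding_index``."""
    n_padding = length - tensor.shape[0]
    padding = tensor.new_full((n_padding,) + tuple(tensor.shape[1:]), padding_index)
    return torch.cat((tensor, padding), dim=0)

def pad_batch(batch, padding_index=PADDING_INDEX, dim=0):
    """Pad a list of tensors and stack them along ``dim`` (per this commit)."""
    lengths = [tensor.shape[0] for tensor in batch]
    max_len = max(lengths)
    padded = [pad_tensor(tensor, max_len, padding_index) for tensor in batch]
    # The patch replaces the hard-coded dim=0 with the caller-supplied dim:
    # for 1-D inputs, dim=0 yields (batch, seq), dim=1 yields (seq, batch).
    return torch.stack(padded, dim=dim).contiguous(), lengths

batch = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
padded0, lengths = pad_batch(batch)       # shape (2, 3), lengths [3, 2]
padded1, _ = pad_batch(batch, dim=1)      # shape (3, 2)
```

Before this change, callers wanting batch-second layouts (e.g. `(seq, batch)` for RNNs) had to transpose after padding; the `dim` parameter pushes that choice into the helper itself.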
