Skip to content

Commit 9e57d85

Browse files
jd7-trfacebook-github-bot
authored andcommitted
Remove KJT GPU test (#2869)
Summary: Pull Request resolved: #2869 KJT's GPU test has been moved to its own test file. Reviewed By: TroyGarden Differential Revision: D72401855 fbshipit-source-id: 2ac75c73d8e89d0627cdd2f7f93f90ec01757be9
1 parent 89b12e6 commit 9e57d85

File tree

1 file changed

+1
-197
lines changed

1 file changed

+1
-197
lines changed

torchrec/sparse/tests/test_jagged_tensor_gpu.py

Lines changed: 1 addition & 197 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,7 @@
1111
import unittest
1212

1313
import torch
14-
from torchrec.sparse.jagged_tensor import (
15-
_regroup_keyed_tensors,
16-
KeyedJaggedTensor,
17-
KeyedTensor,
18-
)
14+
from torchrec.sparse.jagged_tensor import _regroup_keyed_tensors, KeyedTensor
1915
from torchrec.sparse.tests.utils import build_groups, build_kts
2016
from torchrec.test_utils import skip_if_asan_class
2117

@@ -115,195 +111,3 @@ def test_regroup_backward(self) -> None:
115111

116112
torch.allclose(actual_kt_0_grad, expected_kt_0_grad)
117113
torch.allclose(actual_kt_1_grad, expected_kt_1_grad)
118-
119-
120-
@skip_if_asan_class
121-
class TestKeyedJaggedTensorGPU(unittest.TestCase):
122-
def setUp(self) -> None:
123-
super().setUp()
124-
self.device = torch.cuda.current_device()
125-
126-
# pyre-ignore
127-
@unittest.skipIf(
128-
torch.cuda.device_count() <= 0,
129-
"Not enough GPUs, this test requires at least one GPUs",
130-
)
131-
def test_permute(self) -> None:
132-
values = torch.tensor(
133-
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device
134-
)
135-
lengths = torch.tensor([0, 2, 0, 1, 1, 1, 0, 3, 0], device=self.device)
136-
keys = ["index_0", "index_1", "index_2"]
137-
138-
jag_tensor = KeyedJaggedTensor.from_lengths_sync(
139-
values=values,
140-
keys=keys,
141-
lengths=lengths,
142-
)
143-
indices = [1, 0, 2]
144-
permuted_jag_tensor = jag_tensor.permute(indices)
145-
146-
self.assertEqual(permuted_jag_tensor.keys(), ["index_1", "index_0", "index_2"])
147-
self.assertEqual(
148-
permuted_jag_tensor.offset_per_key(),
149-
[0, 3, 5, 8],
150-
)
151-
self.assertEqual(
152-
permuted_jag_tensor.values().tolist(),
153-
[3.0, 4.0, 5.0, 1.0, 2.0, 6.0, 7.0, 8.0],
154-
)
155-
self.assertEqual(
156-
permuted_jag_tensor.lengths().tolist(), [1, 1, 1, 0, 2, 0, 0, 3, 0]
157-
)
158-
self.assertEqual(permuted_jag_tensor.weights_or_none(), None)
159-
160-
# pyre-ignore
161-
@unittest.skipIf(
162-
torch.cuda.device_count() <= 0,
163-
"Not enough GPUs, this test requires at least one GPUs",
164-
)
165-
def test_permute_vb(self) -> None:
166-
values = torch.tensor(
167-
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device
168-
)
169-
lengths = torch.tensor([1, 0, 1, 3, 0, 1, 0, 2, 0], device=self.device)
170-
keys = ["index_0", "index_1", "index_2"]
171-
stride_per_key_per_rank = [[2], [4], [3]]
172-
173-
jag_tensor = KeyedJaggedTensor.from_lengths_sync(
174-
values=values,
175-
keys=keys,
176-
lengths=lengths,
177-
stride_per_key_per_rank=stride_per_key_per_rank,
178-
)
179-
180-
indices = [1, 0, 2]
181-
permuted_jag_tensor = jag_tensor.permute(indices)
182-
183-
self.assertEqual(permuted_jag_tensor.keys(), ["index_1", "index_0", "index_2"])
184-
self.assertEqual(
185-
permuted_jag_tensor.offset_per_key(),
186-
[0, 5, 6, 8],
187-
)
188-
self.assertEqual(
189-
permuted_jag_tensor.values().tolist(),
190-
[2.0, 3.0, 4.0, 5.0, 6.0, 1.0, 7.0, 8.0],
191-
)
192-
self.assertEqual(
193-
permuted_jag_tensor.lengths().tolist(), [1, 3, 0, 1, 1, 0, 0, 2, 0]
194-
)
195-
self.assertEqual(permuted_jag_tensor.weights_or_none(), None)
196-
197-
# pyre-ignore
198-
@unittest.skipIf(
199-
torch.cuda.device_count() <= 0,
200-
"Not enough GPUs, this test requires at least one GPUs",
201-
)
202-
def test_permute_vb_duplicate(self) -> None:
203-
values = torch.tensor(
204-
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device
205-
)
206-
lengths = torch.tensor([1, 0, 1, 3, 0, 1, 0, 2, 0], device=self.device)
207-
keys = ["index_0", "index_1", "index_2"]
208-
stride_per_key_per_rank = [[2], [4], [3]]
209-
210-
jag_tensor = KeyedJaggedTensor.from_lengths_sync(
211-
values=values,
212-
keys=keys,
213-
lengths=lengths,
214-
stride_per_key_per_rank=stride_per_key_per_rank,
215-
)
216-
217-
indices = [1, 1, 0, 0, 2, 2]
218-
permuted_jag_tensor = jag_tensor.permute(indices)
219-
220-
self.assertEqual(
221-
permuted_jag_tensor.keys(),
222-
["index_1", "index_1", "index_0", "index_0", "index_2", "index_2"],
223-
)
224-
self.assertTrue(
225-
torch.equal(
226-
permuted_jag_tensor.values().cpu(),
227-
torch.Tensor(
228-
[
229-
2.0,
230-
3.0,
231-
4.0,
232-
5.0,
233-
6.0,
234-
2.0,
235-
3.0,
236-
4.0,
237-
5.0,
238-
6.0,
239-
1.0,
240-
1.0,
241-
7.0,
242-
8.0,
243-
7.0,
244-
8.0,
245-
]
246-
),
247-
)
248-
)
249-
self.assertTrue(
250-
torch.equal(
251-
permuted_jag_tensor.lengths().cpu(),
252-
torch.IntTensor([1, 3, 0, 1, 1, 3, 0, 1, 1, 0, 1, 0, 0, 2, 0, 0, 2, 0]),
253-
)
254-
)
255-
self.assertEqual(permuted_jag_tensor.weights_or_none(), None)
256-
257-
# pyre-ignore
258-
@unittest.skipIf(
259-
torch.cuda.device_count() <= 0,
260-
"Not enough GPUs, this test requires at least one GPUs",
261-
)
262-
def test_permute_duplicates(self) -> None:
263-
values = torch.tensor(
264-
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device
265-
)
266-
lengths = torch.tensor([0, 2, 0, 1, 1, 1, 0, 3, 0], device=self.device)
267-
keys = ["index_0", "index_1", "index_2"]
268-
269-
jag_tensor = KeyedJaggedTensor.from_lengths_sync(
270-
values=values,
271-
keys=keys,
272-
lengths=lengths,
273-
)
274-
275-
indices = [1, 0, 2, 1, 1]
276-
permuted_jag_tensor = jag_tensor.permute(indices)
277-
278-
self.assertEqual(
279-
permuted_jag_tensor.keys(),
280-
["index_1", "index_0", "index_2", "index_1", "index_1"],
281-
)
282-
self.assertEqual(
283-
permuted_jag_tensor.offset_per_key(),
284-
[0, 3, 5, 8, 11, 14],
285-
)
286-
self.assertEqual(
287-
permuted_jag_tensor.values().tolist(),
288-
[
289-
3.0,
290-
4.0,
291-
5.0,
292-
1.0,
293-
2.0,
294-
6.0,
295-
7.0,
296-
8.0,
297-
3.0,
298-
4.0,
299-
5.0,
300-
3.0,
301-
4.0,
302-
5.0,
303-
],
304-
)
305-
self.assertEqual(
306-
permuted_jag_tensor.lengths().tolist(),
307-
[1, 1, 1, 0, 2, 0, 0, 3, 0, 1, 1, 1, 1, 1, 1],
308-
)
309-
self.assertEqual(permuted_jag_tensor.weights_or_none(), None)

0 commit comments

Comments
 (0)