|
11 | 11 | import unittest
|
12 | 12 |
|
13 | 13 | import torch
|
14 |
| -from torchrec.sparse.jagged_tensor import ( |
15 |
| - _regroup_keyed_tensors, |
16 |
| - KeyedJaggedTensor, |
17 |
| - KeyedTensor, |
18 |
| -) |
| 14 | +from torchrec.sparse.jagged_tensor import _regroup_keyed_tensors, KeyedTensor |
19 | 15 | from torchrec.sparse.tests.utils import build_groups, build_kts
|
20 | 16 | from torchrec.test_utils import skip_if_asan_class
|
21 | 17 |
|
@@ -115,195 +111,3 @@ def test_regroup_backward(self) -> None:
|
115 | 111 |
|
116 | 112 | torch.allclose(actual_kt_0_grad, expected_kt_0_grad)
|
117 | 113 | torch.allclose(actual_kt_1_grad, expected_kt_1_grad)
|
118 |
| - |
119 |
| - |
120 |
| -@skip_if_asan_class |
121 |
| -class TestKeyedJaggedTensorGPU(unittest.TestCase): |
122 |
| - def setUp(self) -> None: |
123 |
| - super().setUp() |
124 |
| - self.device = torch.cuda.current_device() |
125 |
| - |
126 |
| - # pyre-ignore |
127 |
| - @unittest.skipIf( |
128 |
| - torch.cuda.device_count() <= 0, |
129 |
| - "Not enough GPUs, this test requires at least one GPUs", |
130 |
| - ) |
131 |
| - def test_permute(self) -> None: |
132 |
| - values = torch.tensor( |
133 |
| - [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device |
134 |
| - ) |
135 |
| - lengths = torch.tensor([0, 2, 0, 1, 1, 1, 0, 3, 0], device=self.device) |
136 |
| - keys = ["index_0", "index_1", "index_2"] |
137 |
| - |
138 |
| - jag_tensor = KeyedJaggedTensor.from_lengths_sync( |
139 |
| - values=values, |
140 |
| - keys=keys, |
141 |
| - lengths=lengths, |
142 |
| - ) |
143 |
| - indices = [1, 0, 2] |
144 |
| - permuted_jag_tensor = jag_tensor.permute(indices) |
145 |
| - |
146 |
| - self.assertEqual(permuted_jag_tensor.keys(), ["index_1", "index_0", "index_2"]) |
147 |
| - self.assertEqual( |
148 |
| - permuted_jag_tensor.offset_per_key(), |
149 |
| - [0, 3, 5, 8], |
150 |
| - ) |
151 |
| - self.assertEqual( |
152 |
| - permuted_jag_tensor.values().tolist(), |
153 |
| - [3.0, 4.0, 5.0, 1.0, 2.0, 6.0, 7.0, 8.0], |
154 |
| - ) |
155 |
| - self.assertEqual( |
156 |
| - permuted_jag_tensor.lengths().tolist(), [1, 1, 1, 0, 2, 0, 0, 3, 0] |
157 |
| - ) |
158 |
| - self.assertEqual(permuted_jag_tensor.weights_or_none(), None) |
159 |
| - |
160 |
| - # pyre-ignore |
161 |
| - @unittest.skipIf( |
162 |
| - torch.cuda.device_count() <= 0, |
163 |
| - "Not enough GPUs, this test requires at least one GPUs", |
164 |
| - ) |
165 |
| - def test_permute_vb(self) -> None: |
166 |
| - values = torch.tensor( |
167 |
| - [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device |
168 |
| - ) |
169 |
| - lengths = torch.tensor([1, 0, 1, 3, 0, 1, 0, 2, 0], device=self.device) |
170 |
| - keys = ["index_0", "index_1", "index_2"] |
171 |
| - stride_per_key_per_rank = [[2], [4], [3]] |
172 |
| - |
173 |
| - jag_tensor = KeyedJaggedTensor.from_lengths_sync( |
174 |
| - values=values, |
175 |
| - keys=keys, |
176 |
| - lengths=lengths, |
177 |
| - stride_per_key_per_rank=stride_per_key_per_rank, |
178 |
| - ) |
179 |
| - |
180 |
| - indices = [1, 0, 2] |
181 |
| - permuted_jag_tensor = jag_tensor.permute(indices) |
182 |
| - |
183 |
| - self.assertEqual(permuted_jag_tensor.keys(), ["index_1", "index_0", "index_2"]) |
184 |
| - self.assertEqual( |
185 |
| - permuted_jag_tensor.offset_per_key(), |
186 |
| - [0, 5, 6, 8], |
187 |
| - ) |
188 |
| - self.assertEqual( |
189 |
| - permuted_jag_tensor.values().tolist(), |
190 |
| - [2.0, 3.0, 4.0, 5.0, 6.0, 1.0, 7.0, 8.0], |
191 |
| - ) |
192 |
| - self.assertEqual( |
193 |
| - permuted_jag_tensor.lengths().tolist(), [1, 3, 0, 1, 1, 0, 0, 2, 0] |
194 |
| - ) |
195 |
| - self.assertEqual(permuted_jag_tensor.weights_or_none(), None) |
196 |
| - |
197 |
| - # pyre-ignore |
198 |
| - @unittest.skipIf( |
199 |
| - torch.cuda.device_count() <= 0, |
200 |
| - "Not enough GPUs, this test requires at least one GPUs", |
201 |
| - ) |
202 |
| - def test_permute_vb_duplicate(self) -> None: |
203 |
| - values = torch.tensor( |
204 |
| - [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device |
205 |
| - ) |
206 |
| - lengths = torch.tensor([1, 0, 1, 3, 0, 1, 0, 2, 0], device=self.device) |
207 |
| - keys = ["index_0", "index_1", "index_2"] |
208 |
| - stride_per_key_per_rank = [[2], [4], [3]] |
209 |
| - |
210 |
| - jag_tensor = KeyedJaggedTensor.from_lengths_sync( |
211 |
| - values=values, |
212 |
| - keys=keys, |
213 |
| - lengths=lengths, |
214 |
| - stride_per_key_per_rank=stride_per_key_per_rank, |
215 |
| - ) |
216 |
| - |
217 |
| - indices = [1, 1, 0, 0, 2, 2] |
218 |
| - permuted_jag_tensor = jag_tensor.permute(indices) |
219 |
| - |
220 |
| - self.assertEqual( |
221 |
| - permuted_jag_tensor.keys(), |
222 |
| - ["index_1", "index_1", "index_0", "index_0", "index_2", "index_2"], |
223 |
| - ) |
224 |
| - self.assertTrue( |
225 |
| - torch.equal( |
226 |
| - permuted_jag_tensor.values().cpu(), |
227 |
| - torch.Tensor( |
228 |
| - [ |
229 |
| - 2.0, |
230 |
| - 3.0, |
231 |
| - 4.0, |
232 |
| - 5.0, |
233 |
| - 6.0, |
234 |
| - 2.0, |
235 |
| - 3.0, |
236 |
| - 4.0, |
237 |
| - 5.0, |
238 |
| - 6.0, |
239 |
| - 1.0, |
240 |
| - 1.0, |
241 |
| - 7.0, |
242 |
| - 8.0, |
243 |
| - 7.0, |
244 |
| - 8.0, |
245 |
| - ] |
246 |
| - ), |
247 |
| - ) |
248 |
| - ) |
249 |
| - self.assertTrue( |
250 |
| - torch.equal( |
251 |
| - permuted_jag_tensor.lengths().cpu(), |
252 |
| - torch.IntTensor([1, 3, 0, 1, 1, 3, 0, 1, 1, 0, 1, 0, 0, 2, 0, 0, 2, 0]), |
253 |
| - ) |
254 |
| - ) |
255 |
| - self.assertEqual(permuted_jag_tensor.weights_or_none(), None) |
256 |
| - |
257 |
| - # pyre-ignore |
258 |
| - @unittest.skipIf( |
259 |
| - torch.cuda.device_count() <= 0, |
260 |
| - "Not enough GPUs, this test requires at least one GPUs", |
261 |
| - ) |
262 |
| - def test_permute_duplicates(self) -> None: |
263 |
| - values = torch.tensor( |
264 |
| - [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device |
265 |
| - ) |
266 |
| - lengths = torch.tensor([0, 2, 0, 1, 1, 1, 0, 3, 0], device=self.device) |
267 |
| - keys = ["index_0", "index_1", "index_2"] |
268 |
| - |
269 |
| - jag_tensor = KeyedJaggedTensor.from_lengths_sync( |
270 |
| - values=values, |
271 |
| - keys=keys, |
272 |
| - lengths=lengths, |
273 |
| - ) |
274 |
| - |
275 |
| - indices = [1, 0, 2, 1, 1] |
276 |
| - permuted_jag_tensor = jag_tensor.permute(indices) |
277 |
| - |
278 |
| - self.assertEqual( |
279 |
| - permuted_jag_tensor.keys(), |
280 |
| - ["index_1", "index_0", "index_2", "index_1", "index_1"], |
281 |
| - ) |
282 |
| - self.assertEqual( |
283 |
| - permuted_jag_tensor.offset_per_key(), |
284 |
| - [0, 3, 5, 8, 11, 14], |
285 |
| - ) |
286 |
| - self.assertEqual( |
287 |
| - permuted_jag_tensor.values().tolist(), |
288 |
| - [ |
289 |
| - 3.0, |
290 |
| - 4.0, |
291 |
| - 5.0, |
292 |
| - 1.0, |
293 |
| - 2.0, |
294 |
| - 6.0, |
295 |
| - 7.0, |
296 |
| - 8.0, |
297 |
| - 3.0, |
298 |
| - 4.0, |
299 |
| - 5.0, |
300 |
| - 3.0, |
301 |
| - 4.0, |
302 |
| - 5.0, |
303 |
| - ], |
304 |
| - ) |
305 |
| - self.assertEqual( |
306 |
| - permuted_jag_tensor.lengths().tolist(), |
307 |
| - [1, 1, 1, 0, 2, 0, 0, 3, 0, 1, 1, 1, 1, 1, 1], |
308 |
| - ) |
309 |
| - self.assertEqual(permuted_jag_tensor.weights_or_none(), None) |
0 commit comments