@@ -146,8 +146,11 @@ def _assert_unordered_allclose(spmd_result, batch_result, localize=False, **kwar
146146 Raises:
147147 AssertionError: If results do not match.
148148 """
149+ np_spmd_result = _as_numpy (spmd_result )
149150
150- sorted_spmd_result = spmd_result [np .argsort (np .linalg .norm (spmd_result , axis = 1 ))]
151+ sorted_spmd_result = np_spmd_result [
152+ np .argsort (np .linalg .norm (np_spmd_result , axis = 1 ))
153+ ]
151154 if localize :
152155 local_batch_result = _get_local_tensor (batch_result )
153156 sorted_batch_result = local_batch_result [
@@ -158,7 +161,7 @@ def _assert_unordered_allclose(spmd_result, batch_result, localize=False, **kwar
158161 np .argsort (np .linalg .norm (batch_result , axis = 1 ))
159162 ]
160163
161- assert_allclose (_as_numpy ( sorted_spmd_result ) , sorted_batch_result , ** kwargs )
164+ assert_allclose (sorted_spmd_result , sorted_batch_result , ** kwargs )
162165
163166
164167def _assert_kmeans_labels_allclose (
@@ -179,7 +182,11 @@ def _assert_kmeans_labels_allclose(
179182 AssertionError: If clusters are not correctly assigned.
180183 """
181184
185+ np_spmd_labels = _as_numpy (spmd_labels )
186+ np_spmd_centers = _as_numpy (spmd_centers )
182187 local_batch_labels = _get_local_tensor (batch_labels )
183188 assert_allclose (
184- spmd_centers [_as_numpy (spmd_labels )], batch_centers [local_batch_labels ], ** kwargs
189+ np_spmd_centers [np_spmd_labels ],
190+ batch_centers [local_batch_labels ],
191+ ** kwargs ,
185192 )
0 commit comments