Skip to content

Commit 2e42be0

Browse files
ngimel authored and pytorchmergebot committed
1 parent 551f104 commit 2e42be0

File tree

4 files changed

+61
-19
lines changed

4 files changed

+61
-19
lines changed

aten/src/ATen/native/TensorFactories.cpp

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1322,29 +1322,48 @@ Tensor randn_like(
13221322
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ randperm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
13231323

13241324
namespace {
1325+
13251326
template <typename scalar_t>
13261327
void randperm_cpu(Tensor& result, int64_t n, CPUGeneratorImpl* generator) {
13271328
scalar_t* r__data = result.data_ptr<scalar_t>();
13281329

13291330
result.resize_({n});
13301331
int64_t r__stride_0 = result.stride(0);
13311332

1332-
at::parallel_for(
1333-
0,
1334-
n,
1335-
internal::GRAIN_SIZE,
1336-
[&r__data, &r__stride_0](int64_t p_begin, int64_t p_end) {
1337-
for (const auto i : c10::irange(p_begin, p_end)) {
1338-
r__data[i * r__stride_0] = static_cast<scalar_t>(i);
1339-
}
1340-
});
1341-
1342-
for (int64_t i = 0; i < n - 1; i++) {
1343-
// NOLINTNEXTLINE(clang-analyzer-security.insecureAPI.rand)
1344-
int64_t z = generator->random() % (n - i);
1345-
scalar_t sav = r__data[i * r__stride_0];
1346-
r__data[i * r__stride_0] = r__data[(z + i) * r__stride_0];
1347-
r__data[(z + i) * r__stride_0] = sav;
1333+
// for small n, preserve old behavior
1334+
if (n < std::numeric_limits<uint32_t>::max() / 20) {
1335+
at::parallel_for(
1336+
0,
1337+
n,
1338+
internal::GRAIN_SIZE,
1339+
[&r__data, &r__stride_0](int64_t p_begin, int64_t p_end) {
1340+
for (const auto i : c10::irange(p_begin, p_end)) {
1341+
r__data[i * r__stride_0] = static_cast<scalar_t>(i);
1342+
}
1343+
});
1344+
1345+
for (int64_t i = 0; i < n - 1; i++) {
1346+
// NOLINTNEXTLINE(clang-analyzer-security.insecureAPI.rand)
1347+
int64_t z = generator->random() % (n - i);
1348+
scalar_t sav = r__data[i * r__stride_0];
1349+
r__data[i * r__stride_0] = r__data[(z + i) * r__stride_0];
1350+
r__data[(z + i) * r__stride_0] = sav;
1351+
}
1352+
return;
1353+
}
1354+
1355+
// we need to pick a number uniformly distributed between 0 and n
1356+
// when n is of the same order of magnitude as the biggest number returned by
1357+
// random the % result is not uniformly distributed
1358+
// so we use random64(), you'd run out of RAM before you
1359+
// start seeing the skew
1360+
// use no-initialization Fischer-Yates variant
1361+
// https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_.22inside-out.22_algorithm
1362+
for (int64_t i = 0; i < n; i++) {
1363+
int64_t z = (int64_t)(generator->random64() % (i + 1));
1364+
r__data[i * r__stride_0] = i;
1365+
r__data[i * r__stride_0] = r__data[z * r__stride_0];
1366+
r__data[z * r__stride_0] = i;
13481367
}
13491368
}
13501369
} // namespace

test/test_sparse_csr.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1956,7 +1956,7 @@ def test_shape(d1, d2, d3, nnz, transposed, index_dtype):
19561956
@dtypesIfCUDA(*floating_and_complex_types_and(
19571957
*[torch.half] if SM53OrLater and TEST_CUSPARSE_GENERIC else [],
19581958
*[torch.bfloat16] if SM80OrLater and TEST_CUSPARSE_GENERIC else []))
1959-
@precisionOverride({torch.bfloat16: 1e-2, torch.float16: 1e-2})
1959+
@precisionOverride({torch.bfloat16: 3.5e-2, torch.float16: 1e-2})
19601960
def test_sparse_addmm(self, device, dtype):
19611961
def test_shape(m, n, p, nnz, broadcast, index_dtype, alpha_beta=None):
19621962
if alpha_beta is None:
@@ -2617,7 +2617,7 @@ def run_test(m, n, k, nnz, train):
26172617
@skipIfTorchDynamo()
26182618
@onlyCPU
26192619
@dtypes(torch.float32, torch.float64, torch.bfloat16, torch.float16)
2620-
@precisionOverride({torch.bfloat16: 0.01, torch.float16: 0.01})
2620+
@precisionOverride({torch.bfloat16: 0.02, torch.float16: 0.01})
26212621
def test_sparse_mm_reduce(self, device, dtype):
26222622
def run_test(m, n, k, nnz, reduce_type, index_dtype, train):
26232623
csr = self.genSparseCSRTensor((m, n), nnz, dtype=dtype, device=device, index_dtype=index_dtype)

test/test_tensor_creation_ops.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3576,6 +3576,29 @@ def test_randperm(self, device):
35763576
self.assertEqual(non_contiguous_tensor, res)
35773577
self.assertEqual(res.sort().values.long(), torch.arange(n, device=device))
35783578

3579+
3580+
@largeTensorTest("10GB", "cpu")
3581+
@largeTensorTest("40GB", "cuda")
3582+
@slowTest
3583+
def test_randperm_large(self, device):
3584+
# Test even distribution where rand32 might produce skewed "uniform" distribution
3585+
# n_items is chosen to not evenly divide 2**32 and be sufficiently large
3586+
# to easily detect skew
3587+
def decile(index, collection_size):
3588+
return index // (collection_size // 10)
3589+
3590+
n_items = 700_000_000
3591+
shuffled = torch.randperm(n_items, device=device)
3592+
interval = 1_000_000
3593+
shuffled_interval = shuffled[:interval]
3594+
# histogram implemented for float only
3595+
deciles = decile(shuffled_interval, shuffled.shape[0]).float().cpu()
3596+
hist, _ = deciles.histogram(10, range=(0, 10))
3597+
expected_bin = shuffled_interval.shape[0] / 10
3598+
expected_error = math.sqrt(expected_bin) / expected_bin * 3
3599+
error = (hist - expected_bin).abs().max() / expected_bin
3600+
self.assertTrue(error < expected_error, f"error {error} > {expected_error}")
3601+
35793602
# Test exceptions when device and generator types are incompatible
35803603
@onlyCUDA
35813604
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Produces inconsistent errors when run in fbcode.")

test/torch_np/test_random.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def test_1d(self, use_numpy):
8787
@parametrize("use_numpy", [True, False])
8888
def test_2d(self, use_numpy):
8989
# np.shuffle only shuffles the first axis
90-
ax = tnp.asarray([[1, 2, 3], [4, 5, 6]])
90+
ax = tnp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
9191
ox = ax.copy()
9292

9393
tnp.random.seed(1234)

0 commit comments

Comments
 (0)