Skip to content

Commit 5327894

Browse files
malfet authored and pytorchmergebot committed
[BE] Introduce lapack_work_to_int function (pytorch#149682)
That could be used to safely cast floating values to int by adding an ULP, which is a followup after pytorch#146456 Fixes pytorch#149591 (Not adding unittest as it's just going to be too slow) Test plan: ``` % python3 -c "import torch; torch.pinverse(torch.rand(50000, 8193))" ``` Before the change errored out with ``` RuntimeError: false INTERNAL ASSERT FAILED at "pytorch/pytorch/aten/src/ATen/native/BatchLinearAlgebra.cpp":1605, please report a bug to PyTorch. linalg.svd: Argument 12 has illegal value. Most certainly there is a bug in the implementation calling the backend library. ``` Pull Request resolved: pytorch#149682 Approved by: https://github.com/wdvr
1 parent bf6621d commit 5327894

File tree

1 file changed

+28
-11
lines changed

1 file changed

+28
-11
lines changed

aten/src/ATen/native/BatchLinearAlgebraKernel.cpp

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,25 @@ Tensor& cholesky_inverse_kernel_impl(Tensor& result, Tensor& infos, bool upper)
136136
return result;
137137
}
138138

139+
/*
140+
LAPACK query functions return workspace size as floating point value, which means
141+
that it might not be accurately represented if its size exceeds the mantissa of the
142+
corresponding type. Fix it by adding 1 ULP to the value before casting it to int.
143+
For more info see https://github.com/pytorch/pytorch/issues/145801#issuecomment-2631781776
144+
*/
145+
template <typename T>
146+
static inline
147+
std::enable_if_t<std::is_floating_point_v<T>, int> lapack_work_to_int(const T val) {
148+
const auto next_after = std::nextafter(val, std::numeric_limits<T>::infinity());
149+
return std::max<int>(1, std::ceil(next_after));
150+
}
151+
template <typename T>
152+
static inline
153+
std::enable_if_t<c10::is_complex<T>::value, int> lapack_work_to_int(const T val) {
154+
return lapack_work_to_int(val.real());
155+
}
156+
157+
139158
/*
140159
Computes the eigenvalues and eigenvectors of n-by-n matrix 'input'.
141160
This is an in-place routine, content of 'input', 'values', 'vectors' is overwritten.
@@ -178,7 +197,7 @@ void apply_linalg_eig(Tensor& values, Tensor& vectors, Tensor& input, Tensor& in
178197
lapackEig<scalar_t, value_t>(jobvl, jobvr, n, input_data, lda, values_data,
179198
lvectors_data, ldvl, rvectors_data, ldvr, &work_query, -1, rwork_data, &infos_data[0]);
180199

181-
int lwork = std::max<int>(1, static_cast<int>(real_impl<scalar_t, value_t>(work_query)));
200+
int lwork = lapack_work_to_int(work_query);
182201
Tensor work = at::empty({lwork}, input.dtype());
183202
auto work_data = work.mutable_data_ptr<scalar_t>();
184203

@@ -218,6 +237,8 @@ void linalg_eig_kernel(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos,
218237
'compute_eigenvectors' controls whether eigenvectors should be computed.
219238
This function doesn't do any error checks and it's assumed that every argument is valid.
220239
*/
240+
241+
221242
template <typename scalar_t>
222243
void apply_lapack_eigh(const Tensor& values, const Tensor& vectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
223244
#if !AT_BUILD_WITH_LAPACK()
@@ -256,8 +277,7 @@ void apply_lapack_eigh(const Tensor& values, const Tensor& vectors, const Tensor
256277
lapackSyevd<scalar_t, value_t>(jobz, uplo, n, vectors_data, lda, values_data,
257278
&lwork_query, lwork, &rwork_query, lrwork, &iwork_query, liwork, infos_data);
258279

259-
value_t next_after_lw = std::nextafter(real_impl<scalar_t, value_t>(lwork_query), std::numeric_limits<value_t>::infinity());
260-
lwork = std::max<int>(1, std::ceil(next_after_lw));
280+
lwork = lapack_work_to_int(lwork_query);
261281

262282
Tensor work = at::empty({lwork}, vectors.options());
263283
auto work_data = work.mutable_data_ptr<scalar_t>();
@@ -269,8 +289,7 @@ void apply_lapack_eigh(const Tensor& values, const Tensor& vectors, const Tensor
269289
Tensor rwork;
270290
value_t* rwork_data = nullptr;
271291
if (vectors.is_complex()) {
272-
value_t next_after_rwork_query = std::nextafter(rwork_query, std::numeric_limits<value_t>::infinity());
273-
lrwork = std::max<int>(1, std::ceil(next_after_rwork_query));
292+
lrwork = lapack_work_to_int(rwork_query);
274293
rwork = at::empty({lrwork}, values.options());
275294
rwork_data = rwork.mutable_data_ptr<value_t>();
276295
}
@@ -331,7 +350,6 @@ static void apply_geqrf(const Tensor& input, const Tensor& tau) {
331350
"Calling torch.geqrf on a CPU tensor requires compiling ",
332351
"PyTorch with LAPACK. Please use PyTorch built with LAPACK support.");
333352
#else
334-
using value_t = typename c10::scalar_value_type<scalar_t>::type;
335353
auto input_data = input.data_ptr<scalar_t>();
336354
auto tau_data = tau.data_ptr<scalar_t>();
337355
auto input_matrix_stride = matrixStride(input);
@@ -353,7 +371,7 @@ static void apply_geqrf(const Tensor& input, const Tensor& tau) {
353371

354372
// if lwork is less than 'n' then a warning is printed:
355373
// Intel MKL ERROR: Parameter 7 was incorrect on entry to SGEQRF.
356-
lwork = std::max<int>({1, static_cast<int>(n), static_cast<int>(real_impl<scalar_t, value_t>(wkopt))});
374+
lwork = std::max<int>(static_cast<int>(n), lapack_work_to_int(wkopt));
357375
Tensor work = at::empty({lwork}, input.options());
358376

359377
for (const auto i : c10::irange(batch_size)) {
@@ -401,7 +419,6 @@ inline void apply_orgqr(Tensor& self, const Tensor& tau) {
401419
return;
402420
}
403421

404-
using value_t = typename c10::scalar_value_type<scalar_t>::type;
405422
auto self_data = self.data_ptr<scalar_t>();
406423
auto tau_data = tau.const_data_ptr<scalar_t>();
407424
auto self_matrix_stride = matrixStride(self);
@@ -425,7 +442,7 @@ inline void apply_orgqr(Tensor& self, const Tensor& tau) {
425442
scalar_t wkopt;
426443
lapackOrgqr<scalar_t>(m, n, k, self_data, lda, const_cast<scalar_t*>(tau_data), &wkopt, lwork, &info);
427444
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info == 0);
428-
lwork = std::max<int>(1, real_impl<scalar_t, value_t>(wkopt));
445+
lwork = lapack_work_to_int(wkopt);
429446
Tensor work = at::empty({lwork}, self.options());
430447

431448
for (const auto i : c10::irange(batch_size)) {
@@ -544,7 +561,7 @@ void apply_lstsq(const Tensor& A, Tensor& B, Tensor& rank, Tensor& singular_valu
544561
s_working_ptr,
545562
&iwork_opt);
546563

547-
lwork = std::max<int>(1, real_impl<scalar_t, value_t>(work_opt));
564+
lwork = lapack_work_to_int(work_opt);
548565
Tensor work = at::empty({lwork}, A.options());
549566
scalar_t* work_data = work.mutable_data_ptr<scalar_t>();
550567

@@ -1066,7 +1083,7 @@ static void apply_svd(const Tensor& A,
10661083
{
10671084
scalar_t wkopt;
10681085
lapackSvd<scalar_t, value_t>(jobz, m, n, A_data, lda, S_data, U_data, ldu, Vh_data, ldvh, &wkopt, lwork, rwork_data, iwork_data, info_data);
1069-
lwork = std::max<int>(1, real_impl<scalar_t, value_t>(wkopt));
1086+
lwork = lapack_work_to_int(wkopt);
10701087
}
10711088
auto work = std::vector<scalar_t>(lwork);
10721089
auto* const work_data = work.data();

0 commit comments

Comments
 (0)