@@ -136,6 +136,25 @@ Tensor& cholesky_inverse_kernel_impl(Tensor& result, Tensor& infos, bool upper)
136136 return result;
137137}
138138
139+ /*
140+   LAPACK query functions return the workspace size as a floating-point value, which means
141+   that it might not be represented accurately if its magnitude exceeds the mantissa of the
142+   corresponding type. Fix this by adding 1 ULP to the value before casting it to an integer.
143+   For more info see https://github.com/pytorch/pytorch/issues/145801#issuecomment-2631781776
144+ */
145+ template <typename T>
146+ static inline
147+ std::enable_if_t <std::is_floating_point_v<T>, int > lapack_work_to_int (const T val) { // floating-point overload: convert a LAPACK lwork query result to an int workspace size
148+ const auto next_after = std::nextafter (val, std::numeric_limits<T>::infinity ()); // bump by 1 ULP toward +inf so a value truncated by the float's mantissa cannot round down
149+ return std::max<int >(1 , std::ceil (next_after)); // ceil then clamp to >= 1: LAPACK requires lwork >= 1
150+ }
151+ template <typename T>
152+ static inline
153+ std::enable_if_t <c10::is_complex<T>::value, int > lapack_work_to_int (const T val) { // complex overload: LAPACK stores the workspace size in the real part of the complex query result
154+ return lapack_work_to_int (val.real ()); // delegate to the floating-point overload above
155+ }
156+
157+
139158/*
140159 Computes the eigenvalues and eigenvectors of n-by-n matrix 'input'.
141160 This is an in-place routine, content of 'input', 'values', 'vectors' is overwritten.
@@ -178,7 +197,7 @@ void apply_linalg_eig(Tensor& values, Tensor& vectors, Tensor& input, Tensor& in
178197 lapackEig<scalar_t , value_t >(jobvl, jobvr, n, input_data, lda, values_data,
179198 lvectors_data, ldvl, rvectors_data, ldvr, &work_query, -1 , rwork_data, &infos_data[0 ]);
180199
181- int lwork = std::max< int >( 1 , static_cast < int >(real_impl< scalar_t , value_t >( work_query)) );
200+ int lwork = lapack_work_to_int ( work_query);
182201 Tensor work = at::empty ({lwork}, input.dtype ());
183202 auto work_data = work.mutable_data_ptr <scalar_t >();
184203
@@ -218,6 +237,8 @@ void linalg_eig_kernel(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos,
218237 'compute_eigenvectors' controls whether eigenvectors should be computed.
219238 This function doesn't do any error checks and it's assumed that every argument is valid.
220239*/
240+
241+
221242template <typename scalar_t >
222243void apply_lapack_eigh (const Tensor& values, const Tensor& vectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
223244#if !AT_BUILD_WITH_LAPACK()
@@ -256,8 +277,7 @@ void apply_lapack_eigh(const Tensor& values, const Tensor& vectors, const Tensor
256277 lapackSyevd<scalar_t , value_t >(jobz, uplo, n, vectors_data, lda, values_data,
257278 &lwork_query, lwork, &rwork_query, lrwork, &iwork_query, liwork, infos_data);
258279
259- value_t next_after_lw = std::nextafter (real_impl<scalar_t , value_t >(lwork_query), std::numeric_limits<value_t >::infinity ());
260- lwork = std::max<int >(1 , std::ceil (next_after_lw));
280+ lwork = lapack_work_to_int (lwork_query);
261281
262282 Tensor work = at::empty ({lwork}, vectors.options ());
263283 auto work_data = work.mutable_data_ptr <scalar_t >();
@@ -269,8 +289,7 @@ void apply_lapack_eigh(const Tensor& values, const Tensor& vectors, const Tensor
269289 Tensor rwork;
270290 value_t * rwork_data = nullptr ;
271291 if (vectors.is_complex ()) {
272- value_t next_after_rwork_query = std::nextafter (rwork_query, std::numeric_limits<value_t >::infinity ());
273- lrwork = std::max<int >(1 , std::ceil (next_after_rwork_query));
292+ lrwork = lapack_work_to_int (rwork_query);
274293 rwork = at::empty ({lrwork}, values.options ());
275294 rwork_data = rwork.mutable_data_ptr <value_t >();
276295 }
@@ -331,7 +350,6 @@ static void apply_geqrf(const Tensor& input, const Tensor& tau) {
331350 " Calling torch.geqrf on a CPU tensor requires compiling " ,
332351 " PyTorch with LAPACK. Please use PyTorch built with LAPACK support." );
333352#else
334- using value_t = typename c10::scalar_value_type<scalar_t >::type;
335353 auto input_data = input.data_ptr <scalar_t >();
336354 auto tau_data = tau.data_ptr <scalar_t >();
337355 auto input_matrix_stride = matrixStride (input);
@@ -353,7 +371,7 @@ static void apply_geqrf(const Tensor& input, const Tensor& tau) {
353371
354372 // if lwork is less than 'n' then a warning is printed:
355373 // Intel MKL ERROR: Parameter 7 was incorrect on entry to SGEQRF.
356- lwork = std::max<int >({ 1 , static_cast <int >(n), static_cast < int >(real_impl< scalar_t , value_t >( wkopt))} );
374+ lwork = std::max<int >(static_cast <int >(n), lapack_work_to_int ( wkopt));
357375 Tensor work = at::empty ({lwork}, input.options ());
358376
359377 for (const auto i : c10::irange (batch_size)) {
@@ -401,7 +419,6 @@ inline void apply_orgqr(Tensor& self, const Tensor& tau) {
401419 return ;
402420 }
403421
404- using value_t = typename c10::scalar_value_type<scalar_t >::type;
405422 auto self_data = self.data_ptr <scalar_t >();
406423 auto tau_data = tau.const_data_ptr <scalar_t >();
407424 auto self_matrix_stride = matrixStride (self);
@@ -425,7 +442,7 @@ inline void apply_orgqr(Tensor& self, const Tensor& tau) {
425442 scalar_t wkopt;
426443 lapackOrgqr<scalar_t >(m, n, k, self_data, lda, const_cast <scalar_t *>(tau_data), &wkopt, lwork, &info);
427444 TORCH_INTERNAL_ASSERT_DEBUG_ONLY (info == 0 );
428- lwork = std::max< int >( 1 , real_impl< scalar_t , value_t >( wkopt) );
445+ lwork = lapack_work_to_int ( wkopt);
429446 Tensor work = at::empty ({lwork}, self.options ());
430447
431448 for (const auto i : c10::irange (batch_size)) {
@@ -544,7 +561,7 @@ void apply_lstsq(const Tensor& A, Tensor& B, Tensor& rank, Tensor& singular_valu
544561 s_working_ptr,
545562 &iwork_opt);
546563
547- lwork = std::max< int >( 1 , real_impl< scalar_t , value_t >( work_opt) );
564+ lwork = lapack_work_to_int ( work_opt);
548565 Tensor work = at::empty ({lwork}, A.options ());
549566 scalar_t * work_data = work.mutable_data_ptr <scalar_t >();
550567
@@ -1066,7 +1083,7 @@ static void apply_svd(const Tensor& A,
10661083 {
10671084 scalar_t wkopt;
10681085 lapackSvd<scalar_t , value_t >(jobz, m, n, A_data, lda, S_data, U_data, ldu, Vh_data, ldvh, &wkopt, lwork, rwork_data, iwork_data, info_data);
1069- lwork = std::max< int >( 1 , real_impl< scalar_t , value_t >( wkopt) );
1086+ lwork = lapack_work_to_int ( wkopt);
10701087 }
10711088 auto work = std::vector<scalar_t >(lwork);
10721089 auto * const work_data = work.data ();
0 commit comments