Skip to content

Commit f7563c4

Browse files
Fix: lstsq compilation with complex<float> inputs (#251)
This fix resolves a compilation error `gelsd.tcc:146:25: error: invalid conversion from 'const float*' to 'cxxlapack::FLOAT_COMPLEX*'` when linalg::lstsq is called with complex<float> arguments. Note that `complex<double>` and other types were unaffected. - A test case that demonstrates the issue is added to test/test_linalg.cpp - The fix adds const to the relevant interface to correctly match the expected input data type - The test compiles and passes with this fix --------- Co-authored-by: Johan Mabille <[email protected]>
1 parent 8570676 commit f7563c4

File tree

3 files changed

+49
-30
lines changed

3 files changed

+49
-30
lines changed

include/xflens/cxxlapack/netlib/interface/dummy.in.cc

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,21 +1026,21 @@ LAPACK_DECL(cgels)(const char *TRANS,
10261026

10271027
//-- cgelsd --------------------------------------------------------------------
10281028
void
1029-
LAPACK_DECL(cgelsd)(const INTEGER *M,
1030-
const INTEGER *N,
1031-
const INTEGER *NRHS,
1032-
FLOAT_COMPLEX *A,
1033-
const INTEGER *LDA,
1034-
FLOAT_COMPLEX *B,
1035-
const INTEGER *LDB,
1036-
FLOAT *S,
1037-
const FLOAT *RCOND,
1038-
INTEGER *RANK,
1039-
FLOAT_COMPLEX *WORK,
1040-
const INTEGER *LWORK,
1041-
FLOAT *RWORK,
1042-
INTEGER *IWORK,
1043-
INTEGER *INFO)
1029+
LAPACK_DECL(cgelsd)(const INTEGER *M,
1030+
const INTEGER *N,
1031+
const INTEGER *NRHS,
1032+
const FLOAT_COMPLEX *A,
1033+
const INTEGER *LDA,
1034+
FLOAT_COMPLEX *B,
1035+
const INTEGER *LDB,
1036+
FLOAT *S,
1037+
const FLOAT *RCOND,
1038+
INTEGER *RANK,
1039+
FLOAT_COMPLEX *WORK,
1040+
const INTEGER *LWORK,
1041+
FLOAT *RWORK,
1042+
INTEGER *IWORK,
1043+
INTEGER *INFO)
10441044
{
10451045
DEBUG_LAPACK_STUB("cgelsd");
10461046
LAPACK_IMPL(cgelsd)(M,

include/xflens/cxxlapack/netlib/interface/lapack.in.h

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -579,21 +579,21 @@ LAPACK_IMPL(cgels)(const char *TRANS,
579579

580580
//-- cgelsd --------------------------------------------------------------------
581581
void
582-
LAPACK_IMPL(cgelsd)(const INTEGER *M,
583-
const INTEGER *N,
584-
const INTEGER *NRHS,
585-
FLOAT_COMPLEX *A,
586-
const INTEGER *LDA,
587-
FLOAT_COMPLEX *B,
588-
const INTEGER *LDB,
589-
FLOAT *S,
590-
const FLOAT *RCOND,
591-
INTEGER *RANK,
592-
FLOAT_COMPLEX *WORK,
593-
const INTEGER *LWORK,
594-
FLOAT *RWORK,
595-
INTEGER *IWORK,
596-
INTEGER *INFO);
582+
LAPACK_IMPL(cgelsd)(const INTEGER *M,
583+
const INTEGER *N,
584+
const INTEGER *NRHS,
585+
const FLOAT_COMPLEX *A,
586+
const INTEGER *LDA,
587+
FLOAT_COMPLEX *B,
588+
const INTEGER *LDB,
589+
FLOAT *S,
590+
const FLOAT *RCOND,
591+
INTEGER *RANK,
592+
FLOAT_COMPLEX *WORK,
593+
const INTEGER *LWORK,
594+
FLOAT *RWORK,
595+
INTEGER *IWORK,
596+
INTEGER *INFO);
597597

598598
//-- cgelss --------------------------------------------------------------------
599599
void

test/test_linalg.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,25 @@ namespace xt
491491
CHECK(allclose(cel_1, std::get<1>(cres)));
492492
CHECK_EQ(cel_2, std::get<2>(cres));
493493
CHECK(allclose(cel_3, std::get<3>(cres)));
494+
495+
xarray<std::complex<float>> cfarg_0 = {{0.f, 1.f}, {1.f - 3if, 1.f}, {2.f, 1.f}, {3.f, 1.f}};
496+
xarray<std::complex<float>> cfarg_1 = {{-1.f, 0.2f + 4if, 0.9f, 2.1f - 1if}, {2.f, 3if, 2.f, 1.f}};
497+
cfarg_1 = transpose(cfarg_1);
498+
auto cfres = xt::linalg::lstsq(cfarg_0, cfarg_1);
499+
500+
xarray<std::complex<float>, layout_type::column_major> cfel_0 = {
501+
{-0.40425532f - 0.38723404if, -0.61702128f - 0.44680851if},
502+
{1.44680851f + 1.02765957if, 2.51063830f + 0.95744681if}
503+
};
504+
xarray<float> cfel_1 = {16.11787234f, 2.68085106f};
505+
int cfel_2 = 2;
506+
xarray<float> cfel_3 = {5.01295356f, 1.36758789f};
507+
508+
CHECK(allclose(imag(cfel_0), imag(std::get<0>(cfres))));
509+
CHECK(allclose(real(cfel_0), real(std::get<0>(cfres))));
510+
CHECK(allclose(cfel_1, std::get<1>(cfres)));
511+
CHECK_EQ(cfel_2, std::get<2>(cfres));
512+
CHECK(allclose(cfel_3, std::get<3>(cfres)));
494513
}
495514

496515
TEST_CASE("trace")

0 commit comments

Comments
 (0)