Skip to content

Commit f43e4e5

Browse files
[BLAS] change reference iamin to avoid nan comparison issues/warnings (#159)
1 parent 9d238c9 commit f43e4e5

File tree

2 files changed

+48
-24
lines changed

2 files changed

+48
-24
lines changed

src/blas/backends/netlib/netlib_level1.cpp

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,15 @@ int cblas_isamin(int n, const float *x, int incx) {
4545
}
4646
int min_idx = 0;
4747
auto min_val = abs_val(x[0]);
48+
if (sycl::isnan(min_val))
49+
return 0;
4850

49-
for (int logical_i = 0; logical_i < n; ++logical_i) {
51+
for (int logical_i = 1; logical_i < n; ++logical_i) {
5052
int i = logical_i * std::abs(incx);
5153
auto curr_val = abs_val(x[i]);
52-
bool is_first_nan = std::isnan(curr_val) && !std::isnan(min_val);
53-
if (is_first_nan || curr_val < min_val) {
54+
if (sycl::isnan(curr_val))
55+
return logical_i;
56+
if (curr_val < min_val) {
5457
min_idx = logical_i;
5558
min_val = curr_val;
5659
}
@@ -64,12 +67,15 @@ int cblas_idamin(int n, const double *x, int incx) {
6467
}
6568
int min_idx = 0;
6669
auto min_val = abs_val(x[0]);
70+
if (sycl::isnan(min_val))
71+
return 0;
6772

68-
for (int logical_i = 0; logical_i < n; ++logical_i) {
73+
for (int logical_i = 1; logical_i < n; ++logical_i) {
6974
int i = logical_i * std::abs(incx);
7075
auto curr_val = abs_val(x[i]);
71-
bool is_first_nan = std::isnan(curr_val) && !std::isnan(min_val);
72-
if (is_first_nan || curr_val < min_val) {
76+
if (sycl::isnan(curr_val))
77+
return logical_i;
78+
if (curr_val < min_val) {
7379
min_idx = logical_i;
7480
min_val = curr_val;
7581
}
@@ -83,12 +89,15 @@ int cblas_icamin(int n, const std::complex<float> *x, int incx) {
8389
}
8490
int min_idx = 0;
8591
auto min_val = abs_val(x[0]);
92+
if (sycl::isnan(min_val))
93+
return 0;
8694

87-
for (int logical_i = 0; logical_i < n; ++logical_i) {
95+
for (int logical_i = 1; logical_i < n; ++logical_i) {
8896
int i = logical_i * std::abs(incx);
8997
auto curr_val = abs_val(x[i]);
90-
bool is_first_nan = std::isnan(curr_val) && !std::isnan(min_val);
91-
if (is_first_nan || curr_val < min_val) {
98+
if (sycl::isnan(curr_val))
99+
return logical_i;
100+
if (curr_val < min_val) {
92101
min_idx = logical_i;
93102
min_val = curr_val;
94103
}
@@ -102,12 +111,15 @@ int cblas_izamin(int n, const std::complex<double> *x, int incx) {
102111
}
103112
int min_idx = 0;
104113
auto min_val = abs_val(x[0]);
114+
if (sycl::isnan(min_val))
115+
return 0;
105116

106-
for (int logical_i = 0; logical_i < n; ++logical_i) {
117+
for (int logical_i = 1; logical_i < n; ++logical_i) {
107118
int i = logical_i * std::abs(incx);
108119
auto curr_val = abs_val(x[i]);
109-
bool is_first_nan = std::isnan(curr_val) && !std::isnan(min_val);
110-
if (is_first_nan || curr_val < min_val) {
120+
if (sycl::isnan(curr_val))
121+
return logical_i;
122+
if (curr_val < min_val) {
111123
min_idx = logical_i;
112124
min_val = curr_val;
113125
}

tests/unit_tests/blas/include/reference_blas_templates.hpp

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1461,12 +1461,15 @@ int iamin(const int *n, const float *x, const int *incx) {
14611461
}
14621462
int min_idx = 0;
14631463
auto min_val = abs_val(x[0]);
1464+
if (sycl::isnan(min_val))
1465+
return 0;
14641466

1465-
for (int logical_i = 0; logical_i < *n; ++logical_i) {
1467+
for (int logical_i = 1; logical_i < *n; ++logical_i) {
14661468
int i = logical_i * std::abs(*incx);
14671469
auto curr_val = abs_val(x[i]);
1468-
bool is_first_nan = std::isnan(curr_val) && !std::isnan(min_val);
1469-
if (is_first_nan || curr_val < min_val) {
1470+
if (sycl::isnan(curr_val))
1471+
return logical_i;
1472+
if (curr_val < min_val) {
14701473
min_idx = logical_i;
14711474
min_val = curr_val;
14721475
}
@@ -1481,12 +1484,15 @@ int iamin(const int *n, const double *x, const int *incx) {
14811484
}
14821485
int min_idx = 0;
14831486
auto min_val = abs_val(x[0]);
1487+
if (sycl::isnan(min_val))
1488+
return 0;
14841489

1485-
for (int logical_i = 0; logical_i < *n; ++logical_i) {
1490+
for (int logical_i = 1; logical_i < *n; ++logical_i) {
14861491
int i = logical_i * std::abs(*incx);
14871492
auto curr_val = abs_val(x[i]);
1488-
bool is_first_nan = std::isnan(curr_val) && !std::isnan(min_val);
1489-
if (is_first_nan || curr_val < min_val) {
1493+
if (sycl::isnan(curr_val))
1494+
return logical_i;
1495+
if (curr_val < min_val) {
14901496
min_idx = logical_i;
14911497
min_val = curr_val;
14921498
}
@@ -1501,12 +1507,15 @@ int iamin(const int *n, const std::complex<float> *x, const int *incx) {
15011507
}
15021508
int min_idx = 0;
15031509
auto min_val = abs_val(x[0]);
1510+
if (sycl::isnan(min_val))
1511+
return 0;
15041512

1505-
for (int logical_i = 0; logical_i < *n; ++logical_i) {
1513+
for (int logical_i = 1; logical_i < *n; ++logical_i) {
15061514
int i = logical_i * std::abs(*incx);
15071515
auto curr_val = abs_val(x[i]);
1508-
bool is_first_nan = std::isnan(curr_val) && !std::isnan(min_val);
1509-
if (is_first_nan || curr_val < min_val) {
1516+
if (sycl::isnan(curr_val))
1517+
return logical_i;
1518+
if (curr_val < min_val) {
15101519
min_idx = logical_i;
15111520
min_val = curr_val;
15121521
}
@@ -1521,12 +1530,15 @@ int iamin(const int *n, const std::complex<double> *x, const int *incx) {
15211530
}
15221531
int min_idx = 0;
15231532
auto min_val = abs_val(x[0]);
1533+
if (sycl::isnan(min_val))
1534+
return 0;
15241535

1525-
for (int logical_i = 0; logical_i < *n; ++logical_i) {
1536+
for (int logical_i = 1; logical_i < *n; ++logical_i) {
15261537
int i = logical_i * std::abs(*incx);
15271538
auto curr_val = abs_val(x[i]);
1528-
bool is_first_nan = std::isnan(curr_val) && !std::isnan(min_val);
1529-
if (is_first_nan || curr_val < min_val) {
1539+
if (sycl::isnan(curr_val))
1540+
return logical_i;
1541+
if (curr_val < min_val) {
15301542
min_idx = logical_i;
15311543
min_val = curr_val;
15321544
}

0 commit comments

Comments
 (0)