Skip to content

Commit 9d238c9

Browse files
authored
[BLAS] Fix syrk test bug (#158)
* [BLAS] Fix syrk test bug * [BLAS] Limit error printing
1 parent 3cb60dd commit 9d238c9

File tree

2 files changed

+35
-13
lines changed

2 files changed

+35
-13
lines changed

tests/unit_tests/blas/batch/syrk_batch_stride.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,13 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) {
6363

6464
upper_lower = (oneapi::mkl::uplo)(std::rand() % 2);
6565
if ((std::is_same<fp, float>::value) || (std::is_same<fp, double>::value)) {
66-
trans = (oneapi::mkl::transpose)(std::rand() % 2);
66+
trans = (std::rand() % 2) == 0 ? oneapi::mkl::transpose::nontrans
67+
: (std::rand() % 2) == 0 ? oneapi::mkl::transpose::trans
68+
: oneapi::mkl::transpose::conjtrans;
6769
}
6870
else {
69-
tmp = std::rand() % 3;
70-
if (tmp == 2)
71-
trans = oneapi::mkl::transpose::conjtrans;
72-
else
73-
trans = (oneapi::mkl::transpose)tmp;
71+
trans = (std::rand() % 2) == 0 ? oneapi::mkl::transpose::nontrans
72+
: oneapi::mkl::transpose::trans;
7473
}
7574

7675
int64_t stride_a, stride_c;

tests/unit_tests/blas/include/test_common.hpp

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828

2929
#include <CL/sycl.hpp>
3030

31+
#define MAX_NUM_PRINT 20
32+
3133
namespace std {
3234
static sycl::half abs(sycl::half v) {
3335
if (v < sycl::half(0))
@@ -466,7 +468,7 @@ bool check_equal(fp x, fp x_ref, int error_mag, std::ostream &out) {
466468

467469
template <typename fp>
468470
bool check_equal_vector(fp *v, fp *v_ref, int n, int inc, int error_mag, std::ostream &out) {
469-
int abs_inc = std::abs(inc);
471+
int abs_inc = std::abs(inc), count = 0;
470472
bool good = true;
471473

472474
for (int i = 0; i < n; i++) {
@@ -475,6 +477,9 @@ bool check_equal_vector(fp *v, fp *v_ref, int n, int inc, int error_mag, std::os
475477
std::cout << "Difference in entry " << i_actual << ": DPC++ " << v[i * abs_inc]
476478
<< " vs. Reference " << v_ref[i * abs_inc] << std::endl;
477479
good = false;
480+
count++;
481+
if (count > MAX_NUM_PRINT)
482+
return good;
478483
}
479484
}
480485

@@ -483,7 +488,7 @@ bool check_equal_vector(fp *v, fp *v_ref, int n, int inc, int error_mag, std::os
483488

484489
template <typename vec1, typename vec2>
485490
bool check_equal_vector(vec1 &v, vec2 &v_ref, int n, int inc, int error_mag, std::ostream &out) {
486-
int abs_inc = std::abs(inc);
491+
int abs_inc = std::abs(inc), count = 0;
487492
bool good = true;
488493

489494
for (int i = 0; i < n; i++) {
@@ -492,6 +497,9 @@ bool check_equal_vector(vec1 &v, vec2 &v_ref, int n, int inc, int error_mag, std
492497
std::cout << "Difference in entry " << i_actual << ": DPC++ " << v[i * abs_inc]
493498
<< " vs. Reference " << v_ref[i * abs_inc] << std::endl;
494499
good = false;
500+
count++;
501+
if (count > MAX_NUM_PRINT)
502+
return good;
495503
}
496504
}
497505

@@ -501,7 +509,7 @@ bool check_equal_vector(vec1 &v, vec2 &v_ref, int n, int inc, int error_mag, std
501509
template <typename vec1, typename vec2>
502510
bool check_equal_trsv_vector(vec1 &v, vec2 &v_ref, int n, int inc, int error_mag,
503511
std::ostream &out) {
504-
int abs_inc = std::abs(inc);
512+
int abs_inc = std::abs(inc), count = 0;
505513
bool good = true;
506514

507515
for (int i = 0; i < n; i++) {
@@ -510,6 +518,9 @@ bool check_equal_trsv_vector(vec1 &v, vec2 &v_ref, int n, int inc, int error_mag
510518
std::cout << "Difference in entry " << i_actual << ": DPC++ " << v[i * abs_inc]
511519
<< " vs. Reference " << v_ref[i * abs_inc] << std::endl;
512520
good = false;
521+
count++;
522+
if (count > MAX_NUM_PRINT)
523+
return good;
513524
}
514525
}
515526

@@ -520,14 +531,17 @@ template <typename acc1, typename acc2>
520531
bool check_equal_matrix(acc1 &M, acc2 &M_ref, oneapi::mkl::layout layout, int m, int n, int ld,
521532
int error_mag, std::ostream &out) {
522533
bool good = true;
523-
int idx;
534+
int idx, count = 0;
524535
for (int j = 0; j < n; j++) {
525536
for (int i = 0; i < m; i++) {
526537
idx = (layout == oneapi::mkl::layout::column_major) ? i + j * ld : j + i * ld;
527538
if (!check_equal(M[idx], M_ref[idx], error_mag)) {
528539
out << "Difference in entry (" << i << ',' << j << "): DPC++ " << M[idx]
529540
<< " vs. Reference " << M_ref[idx] << std::endl;
530541
good = false;
542+
count++;
543+
if (count > MAX_NUM_PRINT)
544+
return good;
531545
}
532546
}
533547
}
@@ -539,14 +553,17 @@ template <typename fp>
539553
bool check_equal_matrix(fp *M, fp *M_ref, oneapi::mkl::layout layout, int m, int n, int ld,
540554
int error_mag, std::ostream &out) {
541555
bool good = true;
542-
int idx;
556+
int idx, count = 0;
543557
for (int j = 0; j < n; j++) {
544558
for (int i = 0; i < m; i++) {
545559
idx = (layout == oneapi::mkl::layout::column_major) ? i + j * ld : j + i * ld;
546560
if (!check_equal(M[idx], M_ref[idx], error_mag)) {
547561
out << "Difference in entry (" << i << ',' << j << "): DPC++ " << M[idx]
548562
<< " vs. Reference " << M_ref[idx] << std::endl;
549563
good = false;
564+
count++;
565+
if (count > MAX_NUM_PRINT)
566+
return good;
550567
}
551568
}
552569
}
@@ -559,7 +576,7 @@ bool check_equal_matrix(acc1 &M, acc2 &M_ref, oneapi::mkl::layout layout,
559576
oneapi::mkl::uplo upper_lower, int m, int n, int ld, int error_mag,
560577
std::ostream &out) {
561578
bool good = true;
562-
int idx;
579+
int idx, count = 0;
563580
for (int j = 0; j < n; j++) {
564581
for (int i = 0; i < m; i++) {
565582
idx = (layout == oneapi::mkl::layout::column_major) ? i + j * ld : j + i * ld;
@@ -569,6 +586,9 @@ bool check_equal_matrix(acc1 &M, acc2 &M_ref, oneapi::mkl::layout layout,
569586
out << "Difference in entry (" << i << ',' << j << "): DPC++ " << M[idx]
570587
<< " vs. Reference " << M_ref[idx] << std::endl;
571588
good = false;
589+
count++;
590+
if (count > MAX_NUM_PRINT)
591+
return good;
572592
}
573593
}
574594
}
@@ -581,14 +601,17 @@ template <typename acc1, typename acc2>
581601
bool check_equal_trsm_matrix(acc1 &M, acc2 &M_ref, oneapi::mkl::layout layout, int m, int n, int ld,
582602
int error_mag, std::ostream &out) {
583603
bool good = true;
584-
int idx;
604+
int idx, count = 0;
585605
for (int j = 0; j < n; j++) {
586606
for (int i = 0; i < m; i++) {
587607
idx = (layout == oneapi::mkl::layout::column_major) ? i + j * ld : j + i * ld;
588608
if (!check_equal_trsm(M[idx], M_ref[idx], error_mag)) {
589609
out << "Difference in entry (" << i << ',' << j << "): DPC++ " << M[idx]
590610
<< " vs. Reference " << M_ref[idx] << std::endl;
591611
good = false;
612+
count++;
613+
if (count > MAX_NUM_PRINT)
614+
return good;
592615
}
593616
}
594617
}

0 commit comments

Comments
 (0)