Skip to content

Commit 24125e3

Browse files
s-Nicknormallytangent
authored andcommitted
[BLAS][portBLAS] Add try/catch for portblas runtime exception & minor fix (#525)
Signed-off-by: nscipione <[email protected]> * Catch PortBLAS's unsupported exceptions and rethrow as mkl::unimplemented. * Add missing checks for the device having double support in tests.
1 parent 60bdce9 commit 24125e3

File tree

3 files changed

+20
-2
lines changed

3 files changed

+20
-2
lines changed

src/blas/backends/portblas/portblas_common.hpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,12 @@ struct throw_if_unsupported_by_device {
199199
auto fn = [](auto&&... targs) { \
200200
portBLASFunc(std::forward<decltype(targs)>(targs)...); \
201201
}; \
202-
std::apply(fn, args); \
202+
try { \
203+
std::apply(fn, args); \
204+
} \
205+
catch (const ::blas::unsupported_exception& e) { \
206+
throw unimplemented("blas", e.what()); \
207+
} \
203208
} \
204209
else { \
205210
throw unimplemented("blas", "portBLAS function"); \
@@ -215,7 +220,12 @@ struct throw_if_unsupported_by_device {
215220
auto fn = [](auto&&... targs) { \
216221
return portblasFunc(std::forward<decltype(targs)>(targs)...).back(); \
217222
}; \
218-
return std::apply(fn, args); \
223+
try { \
224+
return std::apply(fn, args); \
225+
} \
226+
catch (const ::blas::unsupported_exception& e) { \
227+
throw unimplemented("blas", e.what()); \
228+
} \
219229
} \
220230
else { \
221231
throw unimplemented("blas", "portBLAS function"); \

tests/unit_tests/blas/extensions/omatcopy2.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,8 @@ TEST_P(Omatcopy2Tests, RealSinglePrecision) {
177177
}
178178

179179
TEST_P(Omatcopy2Tests, RealDoublePrecision) {
180+
CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam()));
181+
180182
EXPECT_TRUEORSKIP(test<double>(std::get<0>(GetParam()), std::get<1>(GetParam())));
181183
}
182184

@@ -185,6 +187,8 @@ TEST_P(Omatcopy2Tests, ComplexSinglePrecision) {
185187
}
186188

187189
TEST_P(Omatcopy2Tests, ComplexDoublePrecision) {
190+
CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam()));
191+
188192
EXPECT_TRUEORSKIP(test<std::complex<double>>(std::get<0>(GetParam()), std::get<1>(GetParam())));
189193
}
190194

tests/unit_tests/blas/extensions/omatcopy2_usm.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,8 @@ TEST_P(Omatcopy2UsmTests, RealSinglePrecision) {
186186
}
187187

188188
TEST_P(Omatcopy2UsmTests, RealDoublePrecision) {
189+
CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam()));
190+
189191
EXPECT_TRUEORSKIP(test<double>(std::get<0>(GetParam()), std::get<1>(GetParam())));
190192
}
191193

@@ -194,6 +196,8 @@ TEST_P(Omatcopy2UsmTests, ComplexSinglePrecision) {
194196
}
195197

196198
TEST_P(Omatcopy2UsmTests, ComplexDoublePrecision) {
199+
CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam()));
200+
197201
EXPECT_TRUEORSKIP(test<std::complex<double>>(std::get<0>(GetParam()), std::get<1>(GetParam())));
198202
}
199203

0 commit comments

Comments
 (0)