Skip to content

Commit 9b1513e

Browse files
committed
[mlir][runner] Add more printMemref functions.
Add printMemref for complex data types and index type. Add printMemref for 1d type beyond f32. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D139475
1 parent 769c7ad commit 9b1513e

File tree

2 files changed

+179
-0
lines changed

2 files changed

+179
-0
lines changed

mlir/include/mlir/ExecutionEngine/RunnerUtils.h

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
#include <assert.h>
3535
#include <cmath>
36+
#include <complex>
3637
#include <iostream>
3738

3839
#include "mlir/ExecutionEngine/CRunnerUtils.h"
@@ -72,6 +73,10 @@ void printUnrankedMemRefMetaData(StreamType &os, UnrankedMemRefType<T> &v) {
7273
// Templated instantiation follows.
7374
////////////////////////////////////////////////////////////////////////////////
7475
namespace impl {
76+
using index_type = uint64_t;
77+
using complex64 = std::complex<double>;
78+
using complex32 = std::complex<float>;
79+
7580
template <typename T, int M, int... Dims>
7681
std::ostream &operator<<(std::ostream &os, const Vector<T, M, Dims...> &v);
7782

@@ -350,6 +355,12 @@ extern "C" MLIR_RUNNERUTILS_EXPORT void
350355
_mlir_ciface_printMemrefShapeF32(UnrankedMemRefType<float> *m);
351356
extern "C" MLIR_RUNNERUTILS_EXPORT void
352357
_mlir_ciface_printMemrefShapeF64(UnrankedMemRefType<double> *m);
358+
extern "C" MLIR_RUNNERUTILS_EXPORT void
359+
_mlir_ciface_printMemrefShapeInd(UnrankedMemRefType<impl::index_type> *m);
360+
extern "C" MLIR_RUNNERUTILS_EXPORT void
361+
_mlir_ciface_printMemrefShapeC32(UnrankedMemRefType<impl::complex32> *m);
362+
extern "C" MLIR_RUNNERUTILS_EXPORT void
363+
_mlir_ciface_printMemrefShapeC64(UnrankedMemRefType<impl::complex64> *m);
353364

354365
extern "C" MLIR_RUNNERUTILS_EXPORT void
355366
_mlir_ciface_printMemrefI8(UnrankedMemRefType<int8_t> *m);
@@ -361,13 +372,22 @@ extern "C" MLIR_RUNNERUTILS_EXPORT void
361372
_mlir_ciface_printMemrefF32(UnrankedMemRefType<float> *m);
362373
extern "C" MLIR_RUNNERUTILS_EXPORT void
363374
_mlir_ciface_printMemrefF64(UnrankedMemRefType<double> *m);
375+
extern "C" MLIR_RUNNERUTILS_EXPORT void
376+
_mlir_ciface_printMemrefInd(UnrankedMemRefType<impl::index_type> *m);
377+
extern "C" MLIR_RUNNERUTILS_EXPORT void
378+
_mlir_ciface_printMemrefC32(UnrankedMemRefType<impl::complex32> *m);
379+
extern "C" MLIR_RUNNERUTILS_EXPORT void
380+
_mlir_ciface_printMemrefC64(UnrankedMemRefType<impl::complex64> *m);
364381

365382
extern "C" MLIR_RUNNERUTILS_EXPORT int64_t _mlir_ciface_nanoTime();
366383

367384
extern "C" MLIR_RUNNERUTILS_EXPORT void printMemrefI32(int64_t rank, void *ptr);
368385
extern "C" MLIR_RUNNERUTILS_EXPORT void printMemrefI64(int64_t rank, void *ptr);
369386
extern "C" MLIR_RUNNERUTILS_EXPORT void printMemrefF32(int64_t rank, void *ptr);
370387
extern "C" MLIR_RUNNERUTILS_EXPORT void printMemrefF64(int64_t rank, void *ptr);
388+
extern "C" MLIR_RUNNERUTILS_EXPORT void printMemrefInd(int64_t rank, void *ptr);
389+
extern "C" MLIR_RUNNERUTILS_EXPORT void printMemrefC32(int64_t rank, void *ptr);
390+
extern "C" MLIR_RUNNERUTILS_EXPORT void printMemrefC64(int64_t rank, void *ptr);
371391
extern "C" MLIR_RUNNERUTILS_EXPORT void printCString(char *str);
372392

373393
extern "C" MLIR_RUNNERUTILS_EXPORT void
@@ -381,6 +401,21 @@ _mlir_ciface_printMemref3dF32(StridedMemRefType<float, 3> *m);
381401
extern "C" MLIR_RUNNERUTILS_EXPORT void
382402
_mlir_ciface_printMemref4dF32(StridedMemRefType<float, 4> *m);
383403

404+
extern "C" MLIR_RUNNERUTILS_EXPORT void
405+
_mlir_ciface_printMemref1dI8(StridedMemRefType<int8_t, 1> *m);
406+
extern "C" MLIR_RUNNERUTILS_EXPORT void
407+
_mlir_ciface_printMemref1dI32(StridedMemRefType<int32_t, 1> *m);
408+
extern "C" MLIR_RUNNERUTILS_EXPORT void
409+
_mlir_ciface_printMemref1dI64(StridedMemRefType<int64_t, 1> *m);
410+
extern "C" MLIR_RUNNERUTILS_EXPORT void
411+
_mlir_ciface_printMemref1dF64(StridedMemRefType<double, 1> *m);
412+
extern "C" MLIR_RUNNERUTILS_EXPORT void
413+
_mlir_ciface_printMemref1dInd(StridedMemRefType<impl::index_type, 1> *m);
414+
extern "C" MLIR_RUNNERUTILS_EXPORT void
415+
_mlir_ciface_printMemref1dC32(StridedMemRefType<impl::complex32, 1> *m);
416+
extern "C" MLIR_RUNNERUTILS_EXPORT void
417+
_mlir_ciface_printMemref1dC64(StridedMemRefType<impl::complex64, 1> *m);
418+
384419
extern "C" MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_printMemrefVector4x4xf32(
385420
StridedMemRefType<Vector2D<4, 4, float>, 2> *m);
386421

@@ -390,6 +425,15 @@ extern "C" MLIR_RUNNERUTILS_EXPORT int64_t _mlir_ciface_verifyMemRefF32(
390425
UnrankedMemRefType<float> *actual, UnrankedMemRefType<float> *expected);
391426
extern "C" MLIR_RUNNERUTILS_EXPORT int64_t _mlir_ciface_verifyMemRefF64(
392427
UnrankedMemRefType<double> *actual, UnrankedMemRefType<double> *expected);
428+
extern "C" MLIR_RUNNERUTILS_EXPORT int64_t
429+
_mlir_ciface_verifyMemRefInd(UnrankedMemRefType<impl::index_type> *actual,
430+
UnrankedMemRefType<impl::index_type> *expected);
431+
extern "C" MLIR_RUNNERUTILS_EXPORT int64_t
432+
_mlir_ciface_verifyMemRefC32(UnrankedMemRefType<impl::complex32> *actual,
433+
UnrankedMemRefType<impl::complex32> *expected);
434+
extern "C" MLIR_RUNNERUTILS_EXPORT int64_t
435+
_mlir_ciface_verifyMemRefC64(UnrankedMemRefType<impl::complex64> *actual,
436+
UnrankedMemRefType<impl::complex64> *expected);
393437

394438
extern "C" MLIR_RUNNERUTILS_EXPORT int64_t verifyMemRefI32(int64_t rank,
395439
void *actualPtr,
@@ -400,5 +444,14 @@ extern "C" MLIR_RUNNERUTILS_EXPORT int64_t verifyMemRefF32(int64_t rank,
400444
extern "C" MLIR_RUNNERUTILS_EXPORT int64_t verifyMemRefF64(int64_t rank,
401445
void *actualPtr,
402446
void *expectedPtr);
447+
extern "C" MLIR_RUNNERUTILS_EXPORT int64_t verifyMemRefInd(int64_t rank,
448+
void *actualPtr,
449+
void *expectedPtr);
450+
extern "C" MLIR_RUNNERUTILS_EXPORT int64_t verifyMemRefC32(int64_t rank,
451+
void *actualPtr,
452+
void *expectedPtr);
453+
extern "C" MLIR_RUNNERUTILS_EXPORT int64_t verifyMemRefC64(int64_t rank,
454+
void *actualPtr,
455+
void *expectedPtr);
403456

404457
#endif // MLIR_EXECUTIONENGINE_RUNNERUTILS_H

mlir/lib/ExecutionEngine/RunnerUtils.cpp

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,27 @@ _mlir_ciface_printMemrefShapeF64(UnrankedMemRefType<double> *M) {
5151
std::cout << "\n";
5252
}
5353

54+
extern "C" void
55+
_mlir_ciface_printMemrefShapeInd(UnrankedMemRefType<impl::index_type> *M) {
56+
std::cout << "Unranked Memref ";
57+
printMemRefMetaData(std::cout, DynamicMemRefType<impl::index_type>(*M));
58+
std::cout << "\n";
59+
}
60+
61+
extern "C" void
62+
_mlir_ciface_printMemrefShapeC32(UnrankedMemRefType<impl::complex32> *M) {
63+
std::cout << "Unranked Memref ";
64+
printMemRefMetaData(std::cout, DynamicMemRefType<impl::complex32>(*M));
65+
std::cout << "\n";
66+
}
67+
68+
extern "C" void
69+
_mlir_ciface_printMemrefShapeC64(UnrankedMemRefType<impl::complex64> *M) {
70+
std::cout << "Unranked Memref ";
71+
printMemRefMetaData(std::cout, DynamicMemRefType<impl::complex64>(*M));
72+
std::cout << "\n";
73+
}
74+
5475
extern "C" void _mlir_ciface_printMemrefVector4x4xf32(
5576
StridedMemRefType<Vector2D<4, 4, float>, 2> *M) {
5677
impl::printMemRef(*M);
@@ -76,6 +97,21 @@ extern "C" void _mlir_ciface_printMemrefF64(UnrankedMemRefType<double> *M) {
7697
impl::printMemRef(*M);
7798
}
7899

100+
extern "C" void
101+
_mlir_ciface_printMemrefInd(UnrankedMemRefType<impl::index_type> *M) {
102+
impl::printMemRef(*M);
103+
}
104+
105+
extern "C" void
106+
_mlir_ciface_printMemrefC32(UnrankedMemRefType<impl::complex32> *M) {
107+
impl::printMemRef(*M);
108+
}
109+
110+
extern "C" void
111+
_mlir_ciface_printMemrefC64(UnrankedMemRefType<impl::complex64> *M) {
112+
impl::printMemRef(*M);
113+
}
114+
79115
extern "C" int64_t _mlir_ciface_nanoTime() {
80116
auto now = std::chrono::high_resolution_clock::now();
81117
auto duration = now.time_since_epoch();
@@ -104,6 +140,24 @@ extern "C" void printMemrefF64(int64_t rank, void *ptr) {
104140
_mlir_ciface_printMemrefF64(&descriptor);
105141
}
106142

143+
// Assume index_type is in fact uint64_t.
144+
static_assert(std::is_same<impl::index_type, uint64_t>::value,
145+
"Expected index_type == uint64_t");
146+
extern "C" void printMemrefInd(int64_t rank, void *ptr) {
147+
UnrankedMemRefType<impl::index_type> descriptor = {rank, ptr};
148+
_mlir_ciface_printMemrefInd(&descriptor);
149+
}
150+
151+
extern "C" void printMemrefC32(int64_t rank, void *ptr) {
152+
UnrankedMemRefType<impl::complex32> descriptor = {rank, ptr};
153+
_mlir_ciface_printMemrefC32(&descriptor);
154+
}
155+
156+
extern "C" void printMemrefC64(int64_t rank, void *ptr) {
157+
UnrankedMemRefType<impl::complex64> descriptor = {rank, ptr};
158+
_mlir_ciface_printMemrefC64(&descriptor);
159+
}
160+
107161
extern "C" void printCString(char *str) { printf("%s", str); }
108162

109163
extern "C" void _mlir_ciface_printMemref0dF32(StridedMemRefType<float, 0> *M) {
@@ -122,6 +176,39 @@ extern "C" void _mlir_ciface_printMemref4dF32(StridedMemRefType<float, 4> *M) {
122176
impl::printMemRef(*M);
123177
}
124178

179+
extern "C" void _mlir_ciface_printMemref1dI8(StridedMemRefType<int8_t, 1> *M) {
180+
impl::printMemRef(*M);
181+
}
182+
183+
extern "C" void
184+
_mlir_ciface_printMemref1dI32(StridedMemRefType<int32_t, 1> *M) {
185+
impl::printMemRef(*M);
186+
}
187+
188+
extern "C" void
189+
_mlir_ciface_printMemref1dI64(StridedMemRefType<int64_t, 1> *M) {
190+
impl::printMemRef(*M);
191+
}
192+
193+
extern "C" void _mlir_ciface_printMemref1dF64(StridedMemRefType<double, 1> *M) {
194+
impl::printMemRef(*M);
195+
}
196+
197+
extern "C" void
198+
_mlir_ciface_printMemref1dInd(StridedMemRefType<impl::index_type, 1> *M) {
199+
impl::printMemRef(*M);
200+
}
201+
202+
extern "C" void
203+
_mlir_ciface_printMemref1dC32(StridedMemRefType<impl::complex32, 1> *M) {
204+
impl::printMemRef(*M);
205+
}
206+
207+
extern "C" void
208+
_mlir_ciface_printMemref1dC64(StridedMemRefType<impl::complex64, 1> *M) {
209+
impl::printMemRef(*M);
210+
}
211+
125212
extern "C" int64_t
126213
_mlir_ciface_verifyMemRefI32(UnrankedMemRefType<int32_t> *actual,
127214
UnrankedMemRefType<int32_t> *expected) {
@@ -140,6 +227,24 @@ _mlir_ciface_verifyMemRefF64(UnrankedMemRefType<double> *actual,
140227
return impl::verifyMemRef(*actual, *expected);
141228
}
142229

230+
extern "C" int64_t
231+
_mlir_ciface_verifyMemRefInd(UnrankedMemRefType<impl::index_type> *actual,
232+
UnrankedMemRefType<impl::index_type> *expected) {
233+
return impl::verifyMemRef(*actual, *expected);
234+
}
235+
236+
extern "C" int64_t
237+
_mlir_ciface_verifyMemRefC32(UnrankedMemRefType<impl::complex32> *actual,
238+
UnrankedMemRefType<impl::complex32> *expected) {
239+
return impl::verifyMemRef(*actual, *expected);
240+
}
241+
242+
extern "C" int64_t
243+
_mlir_ciface_verifyMemRefC64(UnrankedMemRefType<impl::complex64> *actual,
244+
UnrankedMemRefType<impl::complex64> *expected) {
245+
return impl::verifyMemRef(*actual, *expected);
246+
}
247+
143248
extern "C" int64_t verifyMemRefI32(int64_t rank, void *actualPtr,
144249
void *expectedPtr) {
145250
UnrankedMemRefType<int32_t> actualDesc = {rank, actualPtr};
@@ -161,4 +266,25 @@ extern "C" int64_t verifyMemRefF64(int64_t rank, void *actualPtr,
161266
return _mlir_ciface_verifyMemRefF64(&actualDesc, &expectedDesc);
162267
}
163268

269+
extern "C" int64_t verifyMemRefInd(int64_t rank, void *actualPtr,
270+
void *expectedPtr) {
271+
UnrankedMemRefType<impl::index_type> actualDesc = {rank, actualPtr};
272+
UnrankedMemRefType<impl::index_type> expectedDesc = {rank, expectedPtr};
273+
return _mlir_ciface_verifyMemRefInd(&actualDesc, &expectedDesc);
274+
}
275+
276+
extern "C" int64_t verifyMemRefC32(int64_t rank, void *actualPtr,
277+
void *expectedPtr) {
278+
UnrankedMemRefType<impl::complex32> actualDesc = {rank, actualPtr};
279+
UnrankedMemRefType<impl::complex32> expectedDesc = {rank, expectedPtr};
280+
return _mlir_ciface_verifyMemRefC32(&actualDesc, &expectedDesc);
281+
}
282+
283+
extern "C" int64_t verifyMemRefC64(int64_t rank, void *actualPtr,
284+
void *expectedPtr) {
285+
UnrankedMemRefType<impl::complex64> actualDesc = {rank, actualPtr};
286+
UnrankedMemRefType<impl::complex64> expectedDesc = {rank, expectedPtr};
287+
return _mlir_ciface_verifyMemRefC64(&actualDesc, &expectedDesc);
288+
}
289+
164290
// NOLINTEND(*-identifier-naming)

0 commit comments

Comments
 (0)