Skip to content

Commit f874ba2

Browse files
Siyuan Chengcarlescufi
authored andcommitted
zdsp: add in-place operaton test for basicmath
Add in-place calculatoin test for every functions in basicmath Signed-off-by: Siyuan Cheng <[email protected]>
1 parent b475e1f commit f874ba2

File tree

5 files changed

+1770
-0
lines changed

5 files changed

+1770
-0
lines changed

tests/subsys/dsp/basicmath/src/f16.c

Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,38 @@ DEFINE_TEST_VARIANT4(basic_math_f16, zdsp_add_f16, 23, in_com1, in_com2, ref_add
5151
DEFINE_TEST_VARIANT4(basic_math_f16, zdsp_add_f16, long, in_com1, in_com2, ref_add,
5252
ARRAY_SIZE(in_com1));
5353

54+
static void test_zdsp_add_f16_in_place(const uint16_t *input1, const uint16_t *input2,
55+
const uint16_t *ref, size_t length)
56+
{
57+
float16_t *output;
58+
59+
/* Allocate output buffer */
60+
output = malloc(length * sizeof(float16_t));
61+
zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
62+
63+
/* Copy input data to output*/
64+
memcpy(output, (float16_t *)input1, length * sizeof(float16_t));
65+
66+
/* Run test function */
67+
zdsp_add_f16(output, (float16_t *)input2, output, length);
68+
69+
/* Validate output */
70+
zassert_true(test_snr_error_f16(length, output, (float16_t *)ref, SNR_ERROR_THRESH),
71+
ASSERT_MSG_SNR_LIMIT_EXCEED);
72+
73+
zassert_true(test_rel_error_f16(length, output, (float16_t *)ref, REL_ERROR_THRESH),
74+
ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
75+
76+
/* Free output buffer */
77+
free(output);
78+
}
79+
80+
DEFINE_TEST_VARIANT4(basic_math_f16, zdsp_add_f16_in_place, 7, in_com1, in_com2, ref_add, 7);
81+
DEFINE_TEST_VARIANT4(basic_math_f16, zdsp_add_f16_in_place, 16, in_com1, in_com2, ref_add, 16);
82+
DEFINE_TEST_VARIANT4(basic_math_f16, zdsp_add_f16_in_place, 23, in_com1, in_com2, ref_add, 23);
83+
DEFINE_TEST_VARIANT4(basic_math_f16, zdsp_add_f16_in_place, long, in_com1, in_com2, ref_add,
84+
ARRAY_SIZE(in_com1));
85+
5486
static void test_zdsp_sub_f16(
5587
const uint16_t *input1, const uint16_t *input2, const uint16_t *ref,
5688
size_t length)
@@ -85,6 +117,38 @@ DEFINE_TEST_VARIANT4(basic_math_f16, zdsp_sub_f16, 23, in_com1, in_com2, ref_sub
85117
DEFINE_TEST_VARIANT4(basic_math_f16, zdsp_sub_f16, long, in_com1, in_com2, ref_sub,
86118
ARRAY_SIZE(in_com1));
87119

120+
static void test_zdsp_sub_f16_in_place(const uint16_t *input1, const uint16_t *input2,
121+
const uint16_t *ref, size_t length)
122+
{
123+
float16_t *output;
124+
125+
/* Allocate output buffer */
126+
output = malloc(length * sizeof(float16_t));
127+
zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
128+
129+
/* Copy input data to output*/
130+
memcpy(output, (float16_t *)input1, length * sizeof(float16_t));
131+
132+
/* Run test function */
133+
zdsp_sub_f16(output, (float16_t *)input2, output, length);
134+
135+
/* Validate output */
136+
zassert_true(test_snr_error_f16(length, output, (float16_t *)ref, SNR_ERROR_THRESH),
137+
ASSERT_MSG_SNR_LIMIT_EXCEED);
138+
139+
zassert_true(test_rel_error_f16(length, output, (float16_t *)ref, REL_ERROR_THRESH),
140+
ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
141+
142+
/* Free output buffer */
143+
free(output);
144+
}
145+
146+
DEFINE_TEST_VARIANT4(basic_math_f16, zdsp_sub_f16_in_place, 7, in_com1, in_com2, ref_sub, 7);
147+
DEFINE_TEST_VARIANT4(basic_math_f16, zdsp_sub_f16_in_place, 16, in_com1, in_com2, ref_sub, 16);
148+
DEFINE_TEST_VARIANT4(basic_math_f16, zdsp_sub_f16_in_place, 23, in_com1, in_com2, ref_sub, 23);
149+
DEFINE_TEST_VARIANT4(basic_math_f16, zdsp_sub_f16_in_place, long, in_com1, in_com2, ref_sub,
150+
ARRAY_SIZE(in_com1));
151+
88152
static void test_zdsp_mult_f16(
89153
const uint16_t *input1, const uint16_t *input2, const uint16_t *ref,
90154
size_t length)
@@ -119,6 +183,38 @@ DEFINE_TEST_VARIANT4(basic_math_f16, zdsp_mult_f16, 23, in_com1, in_com2, ref_mu
119183
DEFINE_TEST_VARIANT4(basic_math_f16, zdsp_mult_f16, long, in_com1, in_com2, ref_mult,
120184
ARRAY_SIZE(in_com1));
121185

186+
static void test_zdsp_mult_f16_in_place(const uint16_t *input1, const uint16_t *input2,
187+
const uint16_t *ref, size_t length)
188+
{
189+
float16_t *output;
190+
191+
/* Allocate output buffer */
192+
output = malloc(length * sizeof(float16_t));
193+
zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
194+
195+
/* Copy input data to output*/
196+
memcpy(output, (float16_t *)input1, length * sizeof(float16_t));
197+
198+
/* Run test function */
199+
zdsp_mult_f16(output, (float16_t *)input2, output, length);
200+
201+
/* Validate output */
202+
zassert_true(test_snr_error_f16(length, output, (float16_t *)ref, SNR_ERROR_THRESH),
203+
ASSERT_MSG_SNR_LIMIT_EXCEED);
204+
205+
zassert_true(test_rel_error_f16(length, output, (float16_t *)ref, REL_ERROR_THRESH),
206+
ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
207+
208+
/* Free output buffer */
209+
free(output);
210+
}
211+
212+
DEFINE_TEST_VARIANT4(basic_math_f16, zdsp_mult_f16_in_place, 7, in_com1, in_com2, ref_mult, 7);
213+
DEFINE_TEST_VARIANT4(basic_math_f16, zdsp_mult_f16_in_place, 16, in_com1, in_com2, ref_mult, 16);
214+
DEFINE_TEST_VARIANT4(basic_math_f16, zdsp_mult_f16_in_place, 23, in_com1, in_com2, ref_mult, 23);
215+
DEFINE_TEST_VARIANT4(basic_math_f16, zdsp_mult_f16_in_place, long, in_com1, in_com2, ref_mult,
216+
ARRAY_SIZE(in_com1));
217+
122218
static void test_zdsp_negate_f16(
123219
const uint16_t *input1, const uint16_t *ref, size_t length)
124220
{
@@ -152,6 +248,38 @@ DEFINE_TEST_VARIANT3(basic_math_f16, zdsp_negate_f16, 23, in_com1, ref_negate, 2
152248
DEFINE_TEST_VARIANT3(basic_math_f16, zdsp_negate_f16, long, in_com1, ref_negate,
153249
ARRAY_SIZE(in_com1));
154250

251+
static void test_zdsp_negate_f16_in_place(const uint16_t *input1, const uint16_t *ref,
252+
size_t length)
253+
{
254+
float16_t *output;
255+
256+
/* Allocate output buffer */
257+
output = malloc(length * sizeof(float16_t));
258+
zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
259+
260+
/* Copy input data to output*/
261+
memcpy(output, (float16_t *)input1, length * sizeof(float16_t));
262+
263+
/* Run test function */
264+
zdsp_negate_f16(output, output, length);
265+
266+
/* Validate output */
267+
zassert_true(test_snr_error_f16(length, output, (float16_t *)ref, SNR_ERROR_THRESH),
268+
ASSERT_MSG_SNR_LIMIT_EXCEED);
269+
270+
zassert_true(test_rel_error_f16(length, output, (float16_t *)ref, REL_ERROR_THRESH),
271+
ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
272+
273+
/* Free output buffer */
274+
free(output);
275+
}
276+
277+
DEFINE_TEST_VARIANT3(basic_math_f16, zdsp_negate_f16_in_place, 7, in_com1, ref_negate, 7);
278+
DEFINE_TEST_VARIANT3(basic_math_f16, zdsp_negate_f16_in_place, 16, in_com1, ref_negate, 16);
279+
DEFINE_TEST_VARIANT3(basic_math_f16, zdsp_negate_f16_in_place, 23, in_com1, ref_negate, 23);
280+
DEFINE_TEST_VARIANT3(basic_math_f16, zdsp_negate_f16_in_place, long, in_com1, ref_negate,
281+
ARRAY_SIZE(in_com1));
282+
155283
static void test_zdsp_offset_f16(
156284
const uint16_t *input1, float16_t scalar, const uint16_t *ref,
157285
size_t length)
@@ -186,6 +314,40 @@ DEFINE_TEST_VARIANT4(basic_math_f16, zdsp_offset_f16, 0p5_23, in_com1, 0.5f, ref
186314
DEFINE_TEST_VARIANT4(basic_math_f16, zdsp_offset_f16, long, in_com1, 0.5f, ref_offset,
187315
ARRAY_SIZE(in_com1));
188316

317+
static void test_zdsp_offset_f16_in_place(const uint16_t *input1, float16_t scalar,
318+
const uint16_t *ref, size_t length)
319+
{
320+
float16_t *output;
321+
322+
/* Allocate output buffer */
323+
output = malloc(length * sizeof(float16_t));
324+
zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
325+
326+
/* Copy input data to output*/
327+
memcpy(output, (float16_t *)input1, length * sizeof(float16_t));
328+
329+
/* Run test function */
330+
zdsp_offset_f16(output, scalar, output, length);
331+
332+
/* Validate output */
333+
zassert_true(test_snr_error_f16(length, output, (float16_t *)ref, SNR_ERROR_THRESH),
334+
ASSERT_MSG_SNR_LIMIT_EXCEED);
335+
336+
zassert_true(test_rel_error_f16(length, output, (float16_t *)ref, REL_ERROR_THRESH),
337+
ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
338+
339+
/* Free output buffer */
340+
free(output);
341+
}
342+
343+
DEFINE_TEST_VARIANT4(basic_math_f16, zdsp_offset_f16_in_place, 0p5_7, in_com1, 0.5f, ref_offset, 7);
344+
DEFINE_TEST_VARIANT4(basic_math_f16, zdsp_offset_f16_in_place, 0p5_16, in_com1, 0.5f, ref_offset,
345+
16);
346+
DEFINE_TEST_VARIANT4(basic_math_f16, zdsp_offset_f16_in_place, 0p5_23, in_com1, 0.5f, ref_offset,
347+
23);
348+
DEFINE_TEST_VARIANT4(basic_math_f16, zdsp_offset_f16_in_place, long, in_com1, 0.5f, ref_offset,
349+
ARRAY_SIZE(in_com1));
350+
189351
static void test_zdsp_scale_f16(
190352
const uint16_t *input1, float16_t scalar, const uint16_t *ref,
191353
size_t length)
@@ -220,6 +382,38 @@ DEFINE_TEST_VARIANT4(basic_math_f16, zdsp_scale_f16, 0p5_23, in_com1, 0.5f, ref_
220382
DEFINE_TEST_VARIANT4(basic_math_f16, zdsp_scale_f16, long, in_com1, 0.5f, ref_scale,
221383
ARRAY_SIZE(in_com1));
222384

385+
static void test_zdsp_scale_f16_in_place(const uint16_t *input1, float16_t scalar,
386+
const uint16_t *ref, size_t length)
387+
{
388+
float16_t *output;
389+
390+
/* Allocate output buffer */
391+
output = malloc(length * sizeof(float16_t));
392+
zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
393+
394+
/* Copy input data to output*/
395+
memcpy(output, (float16_t *)input1, length * sizeof(float16_t));
396+
397+
/* Run test function */
398+
zdsp_scale_f16(output, scalar, output, length);
399+
400+
/* Validate output */
401+
zassert_true(test_snr_error_f16(length, output, (float16_t *)ref, SNR_ERROR_THRESH),
402+
ASSERT_MSG_SNR_LIMIT_EXCEED);
403+
404+
zassert_true(test_rel_error_f16(length, output, (float16_t *)ref, REL_ERROR_THRESH),
405+
ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
406+
407+
/* Free output buffer */
408+
free(output);
409+
}
410+
411+
DEFINE_TEST_VARIANT4(basic_math_f16, zdsp_scale_f16_in_place, 0p5_7, in_com1, 0.5f, ref_scale, 7);
412+
DEFINE_TEST_VARIANT4(basic_math_f16, zdsp_scale_f16_in_place, 0p5_16, in_com1, 0.5f, ref_scale, 16);
413+
DEFINE_TEST_VARIANT4(basic_math_f16, zdsp_scale_f16_in_place, 0p5_23, in_com1, 0.5f, ref_scale, 23);
414+
DEFINE_TEST_VARIANT4(basic_math_f16, zdsp_scale_f16_in_place, long, in_com1, 0.5f, ref_scale,
415+
ARRAY_SIZE(in_com1));
416+
223417
static void test_zdsp_dot_prod_f16(
224418
const uint16_t *input1, const uint16_t *input2, const uint16_t *ref,
225419
size_t length)
@@ -287,6 +481,37 @@ DEFINE_TEST_VARIANT3(basic_math_f16, zdsp_abs_f16, 16, in_com1, ref_abs, 16);
287481
DEFINE_TEST_VARIANT3(basic_math_f16, zdsp_abs_f16, 23, in_com1, ref_abs, 23);
288482
DEFINE_TEST_VARIANT3(basic_math_f16, zdsp_abs_f16, long, in_com1, ref_abs, ARRAY_SIZE(in_com1));
289483

484+
static void test_zdsp_abs_f16_in_place(const uint16_t *input1, const uint16_t *ref, size_t length)
485+
{
486+
float16_t *output;
487+
488+
/* Allocate output buffer */
489+
output = malloc(length * sizeof(float16_t));
490+
zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
491+
492+
/* Copy input data to output*/
493+
memcpy(output, (float16_t *)input1, length * sizeof(float16_t));
494+
495+
/* Run test function */
496+
zdsp_abs_f16(output, output, length);
497+
498+
/* Validate output */
499+
zassert_true(test_snr_error_f16(length, output, (float16_t *)ref, SNR_ERROR_THRESH),
500+
ASSERT_MSG_SNR_LIMIT_EXCEED);
501+
502+
zassert_true(test_rel_error_f16(length, output, (float16_t *)ref, REL_ERROR_THRESH),
503+
"incorrect computation result");
504+
505+
/* Free output buffer */
506+
free(output);
507+
}
508+
509+
DEFINE_TEST_VARIANT3(basic_math_f16, zdsp_abs_f16_in_place, 7, in_com1, ref_abs, 7);
510+
DEFINE_TEST_VARIANT3(basic_math_f16, zdsp_abs_f16_in_place, 16, in_com1, ref_abs, 16);
511+
DEFINE_TEST_VARIANT3(basic_math_f16, zdsp_abs_f16_in_place, 23, in_com1, ref_abs, 23);
512+
DEFINE_TEST_VARIANT3(basic_math_f16, zdsp_abs_f16_in_place, long, in_com1, ref_abs,
513+
ARRAY_SIZE(in_com1));
514+
290515
static void test_zdsp_clip_f16(
291516
const uint16_t *input, const uint16_t *ref, float16_t min, float16_t max, size_t length)
292517
{
@@ -321,4 +546,37 @@ DEFINE_TEST_VARIANT5(basic_math_f16, zdsp_clip_f16, m0p5_0p5, in_clip, ref_clip2
321546
DEFINE_TEST_VARIANT5(basic_math_f16, zdsp_clip_f16, 0p1_0p5, in_clip, ref_clip3,
322547
0.1f, 0.5f, ARRAY_SIZE(ref_clip3));
323548

549+
static void test_zdsp_clip_f16_in_place(const uint16_t *input, const uint16_t *ref, float16_t min,
550+
float16_t max, size_t length)
551+
{
552+
float16_t *output;
553+
554+
/* Allocate output buffer */
555+
output = malloc(length * sizeof(float16_t));
556+
zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
557+
558+
/* Copy input data to output*/
559+
memcpy(output, (float16_t *)input, length * sizeof(float16_t));
560+
561+
/* Run test function */
562+
zdsp_clip_f16(output, output, min, max, length);
563+
564+
/* Validate output */
565+
zassert_true(test_snr_error_f16(length, output, (float16_t *)ref, SNR_ERROR_THRESH),
566+
ASSERT_MSG_SNR_LIMIT_EXCEED);
567+
568+
zassert_true(test_rel_error_f16(length, output, (float16_t *)ref, REL_ERROR_THRESH),
569+
"incorrect computation result");
570+
571+
/* Free output buffer */
572+
free(output);
573+
}
574+
575+
DEFINE_TEST_VARIANT5(basic_math_f16, zdsp_clip_f16_in_place, m0p5_m0p1, in_clip, ref_clip1, -0.5f,
576+
-0.1f, ARRAY_SIZE(ref_clip1));
577+
DEFINE_TEST_VARIANT5(basic_math_f16, zdsp_clip_f16_in_place, m0p5_0p5, in_clip, ref_clip2, -0.5f,
578+
0.5f, ARRAY_SIZE(ref_clip2));
579+
DEFINE_TEST_VARIANT5(basic_math_f16, zdsp_clip_f16_in_place, 0p1_0p5, in_clip, ref_clip3, 0.1f,
580+
0.5f, ARRAY_SIZE(ref_clip3));
581+
324582
ZTEST_SUITE(basic_math_f16, NULL, NULL, NULL, NULL, NULL);

0 commit comments

Comments
 (0)