Skip to content

Commit 6fe6a73

Browse files
headlessNodekgryte
andauthored
feat: add dtype option support in ndarray/flatten-by
PR-URL: #8094 Closes: stdlib-js/metr-issue-tracker#80 Co-authored-by: Athan Reines <[email protected]> Reviewed-by: Athan Reines <[email protected]>
1 parent 909a46e commit 6fe6a73

File tree

6 files changed

+214
-11
lines changed

6 files changed

+214
-11
lines changed

lib/node_modules/@stdlib/ndarray/flatten-by/README.md

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ The function accepts the following options:
7878

7979
- **depth**: maximum number of input [ndarray][@stdlib/ndarray/ctor] dimensions to flatten.
8080

81+
- **dtype**: output ndarray [data type][@stdlib/ndarray/dtypes]. By default, the function returns an [ndarray][@stdlib/ndarray/ctor] having the same [data type][@stdlib/ndarray/dtypes] as a provided input [ndarray][@stdlib/ndarray/ctor].
82+
8183
By default, the function flattens all dimensions of the input [ndarray][@stdlib/ndarray/ctor]. To flatten to a desired depth, specify the `depth` option.
8284

8385
```javascript
@@ -126,6 +128,33 @@ var arr = ndarray2array( y );
126128
// returns [ 2.0, 6.0, 10.0, 4.0, 8.0, 12.0 ]
127129
```
128130

131+
By default, the output ndarray [data type][@stdlib/ndarray/dtypes] is inferred from the input [ndarray][@stdlib/ndarray/ctor]. To return an ndarray with a different [data type][@stdlib/ndarray/dtypes], specify the `dtype` option.
132+
133+
```javascript
134+
var array = require( '@stdlib/ndarray/array' );
135+
var dtype = require( '@stdlib/ndarray/dtype' );
136+
var ndarray2array = require( '@stdlib/ndarray/to-array' );
137+
138+
function scale( value ) {
139+
return value * 2.0;
140+
}
141+
142+
var x = array( [ [ [ 1.0, 2.0 ] ], [ [ 3.0, 4.0 ] ], [ [ 5.0, 6.0 ] ] ] );
143+
// returns <ndarray>
144+
145+
var opts = {
146+
'dtype': 'float32'
147+
};
148+
var y = flattenBy( x, opts, scale );
149+
// returns <ndarray>
150+
151+
var dt = dtype( y );
152+
// returns 'float32'
153+
154+
var arr = ndarray2array( y );
155+
// returns [ 2.0, 4.0, 6.0, 8.0, 10.0, 12.0 ]
156+
```
157+
129158
To set the callback function execution context, provide a `thisArg`.
130159

131160
<!-- eslint-disable no-invalid-this, max-len -->
@@ -224,6 +253,8 @@ console.log( ndarray2array( y ) );
224253

225254
[@stdlib/ndarray/ctor]: https://github.com/stdlib-js/stdlib/tree/develop/lib/node_modules/%40stdlib/ndarray/ctor
226255

256+
[@stdlib/ndarray/dtypes]: https://github.com/stdlib-js/stdlib/tree/develop/lib/node_modules/%40stdlib/ndarray/dtypes
257+
227258
[@stdlib/ndarray/orders]: https://github.com/stdlib-js/stdlib/tree/develop/lib/node_modules/%40stdlib/ndarray/orders
228259

229260
</section>

lib/node_modules/@stdlib/ndarray/flatten-by/docs/repl.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@
2626

2727
Default: 'row-major'.
2828

29+
options.dtype: string (optional)
30+
Output ndarray data type. By default, the function returns an ndarray
31+
having the same data type as the provided input ndarray.
32+
2933
fcn: Function
3034
Callback function.
3135

lib/node_modules/@stdlib/ndarray/flatten-by/docs/types/index.d.ts

Lines changed: 62 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
/// <reference types="@stdlib/types"/>
2222

23-
import { typedndarray, genericndarray, Order } from '@stdlib/types/ndarray';
23+
import { typedndarray, genericndarray, Order, DataTypeMap } from '@stdlib/types/ndarray';
2424
import { ComplexLike } from '@stdlib/types/complex';
2525

2626
/**
@@ -68,9 +68,9 @@ type Ternary<T, U, V, ThisArg> = ( this: ThisArg, value: T, indices: Array<numbe
6868
type Callback<T, U, V, ThisArg> = Nullary<V, ThisArg> | Unary<T, V, ThisArg> | Binary<T, V, ThisArg> | Ternary<T, U, V, ThisArg>;
6969

7070
/**
71-
* Interface defining function options.
71+
* Interface defining "base" function options.
7272
*/
73-
interface Options {
73+
interface BaseOptions {
7474
/**
7575
* Maximum number of dimensions to flatten.
7676
*
@@ -97,6 +97,16 @@ interface Options {
9797
order?: Order | 'same' | 'any';
9898
}
9999

100+
/**
101+
* Function options.
102+
*/
103+
type Options<U> = BaseOptions & {
104+
/**
105+
* Output ndarray data type.
106+
*/
107+
dtype: U;
108+
};
109+
100110
/**
101111
* Flattens an ndarray according to a callback function.
102112
*
@@ -232,6 +242,7 @@ declare function flattenBy<T = unknown, U extends genericndarray<T> = genericnda
232242
* @param options - function options
233243
* @param options.depth - maximum number of dimensions to flatten
234244
* @param options.order - order in which input ndarray elements should be flattened
245+
* @param options.dtype - output ndarray data type
235246
* @param fcn - callback function
236247
* @param thisArg - callback execution context
237248
* @returns output ndarray
@@ -263,7 +274,7 @@ declare function flattenBy<T = unknown, U extends genericndarray<T> = genericnda
263274
* var arr = ndarray2array( y );
264275
* // returns [ 2.0, 4.0, 6.0, 8.0, 10.0, 12.0 ]
265276
*/
266-
declare function flattenBy<T extends typedndarray<number> = typedndarray<number>, ThisArg = unknown>( x: T, options: Options, fcn: Callback<number, T, number, ThisArg>, thisArg?: ThisParameterType<Callback<number, T, number, ThisArg>> ): T;
277+
declare function flattenBy<T extends typedndarray<number> = typedndarray<number>, ThisArg = unknown>( x: T, options: BaseOptions, fcn: Callback<number, T, number, ThisArg>, thisArg?: ThisParameterType<Callback<number, T, number, ThisArg>> ): T;
267278

268279
/**
269280
* Flattens an ndarray according to a callback function.
@@ -272,6 +283,7 @@ declare function flattenBy<T extends typedndarray<number> = typedndarray<number>
272283
* @param options - function options
273284
* @param options.depth - maximum number of dimensions to flatten
274285
* @param options.order - order in which input ndarray elements should be flattened
286+
* @param options.dtype - output ndarray data type
275287
* @param fcn - callback function
276288
* @param thisArg - callback execution context
277289
* @returns output ndarray
@@ -300,7 +312,7 @@ declare function flattenBy<T extends typedndarray<number> = typedndarray<number>
300312
* var y = flattenBy( x, opts, identity );
301313
* // returns <ndarray>
302314
*/
303-
declare function flattenBy<T extends ComplexLike = ComplexLike, U extends typedndarray<T> = typedndarray<T>, ThisArg = unknown>( x: U, options: Options, fcn: Callback<T, U, T, ThisArg>, thisArg?: ThisParameterType<Callback<T, U, T, ThisArg>> ): U;
315+
declare function flattenBy<T extends ComplexLike = ComplexLike, U extends typedndarray<T> = typedndarray<T>, ThisArg = unknown>( x: U, options: BaseOptions, fcn: Callback<T, U, T, ThisArg>, thisArg?: ThisParameterType<Callback<T, U, T, ThisArg>> ): U;
304316

305317
/**
306318
* Flattens an ndarray according to a callback function.
@@ -309,6 +321,7 @@ declare function flattenBy<T extends ComplexLike = ComplexLike, U extends typedn
309321
* @param options - function options
310322
* @param options.depth - maximum number of dimensions to flatten
311323
* @param options.order - order in which input ndarray elements should be flattened
324+
* @param options.dtype - output ndarray data type
312325
* @param fcn - callback function
313326
* @param thisArg - callback execution context
314327
* @returns output ndarray
@@ -340,7 +353,7 @@ declare function flattenBy<T extends ComplexLike = ComplexLike, U extends typedn
340353
* var arr = ndarray2array( y );
341354
* // returns [ false, true, false, true, false, true ]
342355
*/
343-
declare function flattenBy<T extends typedndarray<boolean> = typedndarray<boolean>, ThisArg = unknown>( x: T, options: Options, fcn: Callback<boolean, T, boolean, ThisArg>, thisArg?: ThisParameterType<Callback<boolean, T, boolean, ThisArg>> ): T;
356+
declare function flattenBy<T extends typedndarray<boolean> = typedndarray<boolean>, ThisArg = unknown>( x: T, options: BaseOptions, fcn: Callback<boolean, T, boolean, ThisArg>, thisArg?: ThisParameterType<Callback<boolean, T, boolean, ThisArg>> ): T;
344357

345358
/**
346359
* Flattens an ndarray according to a callback function.
@@ -349,6 +362,7 @@ declare function flattenBy<T extends typedndarray<boolean> = typedndarray<boolea
349362
* @param options - function options
350363
* @param options.depth - maximum number of dimensions to flatten
351364
* @param options.order - order in which input ndarray elements should be flattened
365+
* @param options.dtype - output ndarray data type
352366
* @param fcn - callback function
353367
* @param thisArg - callback execution context
354368
* @returns output ndarray
@@ -379,7 +393,48 @@ declare function flattenBy<T extends typedndarray<boolean> = typedndarray<boolea
379393
* var arr = ndarray2array( y );
380394
* // returns [ 2.0, 4.0, 6.0, 8.0, 10.0, 12.0 ]
381395
*/
382-
declare function flattenBy<T = unknown, U extends genericndarray<T> = genericndarray<T>, V = unknown, W extends genericndarray<V> = genericndarray<V>, ThisArg = unknown>( x: U, options: Options, fcn: Callback<T, U, V, ThisArg>, thisArg?: ThisParameterType<Callback<T, U, V, ThisArg>> ): W;
396+
declare function flattenBy<T = unknown, U extends genericndarray<T> = genericndarray<T>, V = unknown, W extends genericndarray<V> = genericndarray<V>, ThisArg = unknown>( x: U, options: BaseOptions, fcn: Callback<T, U, V, ThisArg>, thisArg?: ThisParameterType<Callback<T, U, V, ThisArg>> ): W;
397+
398+
/**
399+
* Flattens an ndarray according to a callback function.
400+
*
401+
* @param x - input ndarray
402+
* @param options - function options
403+
* @param options.depth - maximum number of dimensions to flatten
404+
* @param options.order - order in which input ndarray elements should be flattened
405+
* @param options.dtype - output ndarray data type
406+
* @param fcn - callback function
407+
* @param thisArg - callback execution context
408+
* @returns output ndarray
409+
*
410+
* @example
411+
* var Float64Array = require( '@stdlib/array/float64' );
412+
* var ndarray = require( '@stdlib/ndarray/ctor' );
413+
* var ndarray2array = require( '@stdlib/ndarray/to-array' );
414+
*
415+
* function scale( value ) {
416+
* return value * 2.0;
417+
* }
418+
*
419+
* var buffer = new Float64Array( [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 ] );
420+
* var shape = [ 3, 1, 2 ];
421+
* var strides = [ 2, 2, 1 ];
422+
* var offset = 0;
423+
*
424+
* var x = ndarray( 'float64', buffer, shape, strides, offset, 'row-major' );
425+
* // return <ndarray>
426+
*
427+
* var opts = {
428+
* 'depth': 2
429+
* };
430+
*
431+
* var y = flattenBy( x, opts, scale );
432+
* // returns <ndarray>
433+
*
434+
* var arr = ndarray2array( y );
435+
* // returns [ 2.0, 4.0, 6.0, 8.0, 10.0, 12.0 ]
436+
*/
437+
declare function flattenBy<T = unknown, U extends typedndarray<T> | genericndarray<T> = typedndarray<T>, V = unknown, W extends keyof DataTypeMap<T> = 'generic', ThisArg = unknown>( x: U, options: Options<W>, fcn: Callback<T, U, V, ThisArg>, thisArg?: ThisParameterType<Callback<T, U, V, ThisArg>> ): DataTypeMap<V>[W];
383438

384439

385440
// EXPORTS //

lib/node_modules/@stdlib/ndarray/flatten-by/docs/types/test.ts

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,11 @@ function identity( x: any ): any {
8686
flattenBy( zeros( 'generic', sh, ord ), {}, identity ); // $ExpectType genericndarray<number>
8787
flattenBy( zeros( 'generic', sh, ord ), identity, {} ); // $ExpectType genericndarray<number>
8888
flattenBy( zeros( 'generic', sh, ord ), {}, identity, {} ); // $ExpectType genericndarray<number>
89+
90+
flattenBy( zeros( 'float64', sh, ord ), { 'dtype': 'float32' }, identity ); // $ExpectType float32ndarray
91+
flattenBy( zeros( 'float64', sh, ord ), { 'dtype': 'generic' }, identity ); // $ExpectType genericndarray<any>
92+
flattenBy( zeros( 'generic', sh, ord ), { 'dtype': 'float64' }, identity ); // $ExpectType float64ndarray
93+
flattenBy( zeros( 'generic', sh, ord ), { 'dtype': 'generic' }, identity ); // $ExpectType genericndarray<any>
8994
}
9095

9196
// The compiler throws an error if the function is provided a first argument which is not an ndarray...
@@ -178,6 +183,23 @@ function identity( x: any ): any {
178183
flattenBy( x, { 'order': [ 1 ] }, identity, {} ); // $ExpectError
179184
}
180185

186+
// The compiler throws an error if the function is provided a second argument with invalid `dtype` option...
187+
{
188+
const x = zeros( 'generic', [ 2, 2, 2 ], 'row-major' );
189+
190+
flattenBy( x, { 'dtype': '5' }, identity ); // $ExpectError
191+
flattenBy( x, { 'dtype': true }, identity ); // $ExpectError
192+
flattenBy( x, { 'dtype': false }, identity ); // $ExpectError
193+
flattenBy( x, { 'dtype': null }, identity ); // $ExpectError
194+
flattenBy( x, { 'dtype': [ 1 ] }, identity ); // $ExpectError
195+
196+
flattenBy( x, { 'dtype': '5' }, identity, {} ); // $ExpectError
197+
flattenBy( x, { 'dtype': true }, identity, {} ); // $ExpectError
198+
flattenBy( x, { 'dtype': false }, identity, {} ); // $ExpectError
199+
flattenBy( x, { 'dtype': null }, identity, {} ); // $ExpectError
200+
flattenBy( x, { 'dtype': [ 1 ] }, identity, {} ); // $ExpectError
201+
}
202+
181203
// The compiler throws an error if the function is provided a callback which is not a function...
182204
{
183205
const x = zeros( 'generic', [ 2, 2 ], 'row-major' );

lib/node_modules/@stdlib/ndarray/flatten-by/lib/main.js

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ var COL_MAJOR = 'column-major';
5555
* @param {Options} [options] - function options
5656
* @param {NonNegativeInteger} [options.depth] - maximum number of dimensions to flatten
5757
* @param {string} [options.order='row-major'] - order in which input ndarray elements should be flattened
58+
* @param {*} [options.dtype] - output ndarray data type
5859
* @param {Function} fcn - callback function
5960
* @param {*} [thisArg] - callback execution context
6061
* @throws {TypeError} first argument must be an ndarray-like object
@@ -101,8 +102,9 @@ function flattenBy( x, options, fcn, thisArg ) {
101102

102103
// Define default options:
103104
opts = {
104-
'depth': xsh.length, // by default, flatten to a one-dimensional ndarray
105-
'order': ROW_MAJOR // by default, flatten in lexicographic order (i.e., trailing dimensions first; e.g., if `x` is a matrix, flatten row-by-row)
105+
'depth': xsh.length, // by default, flatten to a one-dimensional ndarray
106+
'order': ROW_MAJOR, // by default, flatten in lexicographic order (i.e., trailing dimensions first; e.g., if `x` is a matrix, flatten row-by-row)
107+
'dtype': getDType( x )
106108
};
107109

108110
// Case: flattenBy( x, fcn )
@@ -165,16 +167,21 @@ function flattenBy( x, options, fcn, thisArg ) {
165167
throw new TypeError( format( 'invalid option. `%s` option must be a recognized order. Option: `%s`.', 'order', options.order ) );
166168
}
167169
}
170+
if ( hasOwnProp( options, 'dtype' ) ) {
171+
// Delegate `dtype` validation to `emptyLike` during output array creation:
172+
opts.dtype = options.dtype;
173+
}
168174
}
169175
// Create an output ndarray having contiguous memory:
170176
y = emptyLike( x, {
171177
'shape': flattenShape( xsh, opts.depth ),
172-
'order': opts.order
178+
'order': opts.order,
179+
'dtype': opts.dtype
173180
});
174181

175182
// Create a view on top of output ndarray having the same shape as the input ndarray:
176183
st = ( xsh.length > 0 ) ? shape2strides( xsh, opts.order ) : [ 0 ];
177-
view = ndarray( getDType( y ), getData( y ), xsh, st, 0, opts.order );
184+
view = ndarray( opts.dtype, getData( y ), xsh, st, 0, opts.order );
178185

179186
// Transform and assign elements to the output ndarray:
180187
map( [ x, view ], cb, ctx );

lib/node_modules/@stdlib/ndarray/flatten-by/test/test.js

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,11 @@
2424

2525
var tape = require( 'tape' );
2626
var isSameFloat64Array = require( '@stdlib/assert/is-same-float64array' );
27+
var isSameFloat32Array = require( '@stdlib/assert/is-same-float32array' );
2728
var zeros = require( '@stdlib/ndarray/zeros' );
2829
var ndarray = require( '@stdlib/ndarray/ctor' );
2930
var Float64Array = require( '@stdlib/array/float64' );
31+
var Float32Array = require( '@stdlib/array/float32' );
3032
var identity = require( '@stdlib/number/float64/base/identity' );
3133
var getDType = require( '@stdlib/ndarray/dtype' );
3234
var getShape = require( '@stdlib/ndarray/shape' );
@@ -272,6 +274,39 @@ tape( 'the function throws an error if provided an invalid `order` option', func
272274
}
273275
});
274276

277+
tape( 'the function throws an error if provided an invalid `dtype` option', function test( t ) {
278+
var values;
279+
var i;
280+
281+
values = [
282+
'foo',
283+
'bar',
284+
1,
285+
NaN,
286+
true,
287+
false,
288+
void 0,
289+
null,
290+
[],
291+
{},
292+
function noop() {}
293+
];
294+
295+
for ( i = 0; i < values.length; i++ ) {
296+
t.throws( badValue( values[i] ), TypeError, 'throws an error when provided '+ values[i] );
297+
}
298+
t.end();
299+
300+
function badValue( value ) {
301+
return function badValue() {
302+
var opts = {
303+
'dtype': value
304+
};
305+
flattenBy( zeros( [ 2 ] ), opts, identity );
306+
};
307+
}
308+
});
309+
275310
tape( 'the function throws an error if provided a callback argument which is not a function', function test( t ) {
276311
var values;
277312
var i;
@@ -1534,6 +1569,55 @@ tape( 'the function supports flattening a one-dimensional input ndarray (order=a
15341569
t.end();
15351570
});
15361571

1572+
tape( 'the function supports specifying the output ndarray data type', function test( t ) {
1573+
var expected;
1574+
var xbuf;
1575+
var opts;
1576+
var ord;
1577+
var sh;
1578+
var st;
1579+
var dt;
1580+
var o;
1581+
var x;
1582+
var y;
1583+
1584+
dt = 'float64';
1585+
ord = 'row-major';
1586+
sh = [ 2, 2, 2 ];
1587+
st = shape2strides( sh, ord );
1588+
o = strides2offset( sh, st );
1589+
1590+
/*
1591+
* [
1592+
* [
1593+
* [ 1.0, 2.0 ],
1594+
* [ 3.0, 4.0 ]
1595+
* ],
1596+
* [
1597+
* [ 5.0, 6.0 ],
1598+
* [ 7.0, 8.0 ]
1599+
* ]
1600+
* ]
1601+
*/
1602+
xbuf = new Float64Array( [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0 ] );
1603+
x = new ndarray( dt, xbuf, sh, st, o, ord );
1604+
1605+
opts = {
1606+
'dtype': 'float32'
1607+
};
1608+
y = flattenBy( x, opts, identity );
1609+
expected = new Float32Array( [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0 ] );
1610+
1611+
t.notEqual( y, x, 'returns expected value' );
1612+
t.notEqual( getData( y ), xbuf, 'returns expected value' );
1613+
t.strictEqual( isSameFloat32Array( getData( y ), expected ), true, 'returns expected value' );
1614+
t.deepEqual( getShape( y ), [ 8 ], 'returns expected value' );
1615+
t.strictEqual( getDType( y ), 'float32', 'returns expected value' );
1616+
t.strictEqual( getOrder( y ), ord, 'returns expected value' );
1617+
1618+
t.end();
1619+
});
1620+
15371621
tape( 'the function supports specifying the callback execution context (row-major)', function test( t ) {
15381622
var expected;
15391623
var indices;

0 commit comments

Comments
 (0)