Skip to content

Commit 0ed631d

Browse files
headlessNodekgryte
andauthored
feat: add dtype option support in ndarray/flatten
PR-URL: #8091 Closes: stdlib-js/metr-issue-tracker#79 Co-authored-by: Athan Reines <[email protected]> Reviewed-by: Athan Reines <[email protected]>
1 parent 66605d6 commit 0ed631d

File tree

6 files changed

+108
-5
lines changed

6 files changed

+108
-5
lines changed

lib/node_modules/@stdlib/ndarray/flatten/README.md

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ The function accepts the following options:
7272

7373
- **depth**: maximum number of input [ndarray][@stdlib/ndarray/ctor] dimensions to flatten.
7474

75+
- **dtype**: output ndarray [data type][@stdlib/ndarray/dtypes]. By default, the function returns an [ndarray][@stdlib/ndarray/ctor] having the same [data type][@stdlib/ndarray/dtypes] as a provided input [ndarray][@stdlib/ndarray/ctor].
76+
7577
By default, the function flattens all dimensions of the input [ndarray][@stdlib/ndarray/ctor]. To flatten to a desired depth, specify the `depth` option.
7678

7779
```javascript
@@ -108,6 +110,28 @@ var arr = ndarray2array( y );
108110
// returns [ 1.0, 3.0, 5.0, 2.0, 4.0, 6.0 ]
109111
```
110112

113+
By default, the output ndarray [data type][@stdlib/ndarray/dtypes] is inferred from the input [ndarray][@stdlib/ndarray/ctor]. To return an ndarray with a different [data type][@stdlib/ndarray/dtypes], specify the `dtype` option.
114+
115+
```javascript
116+
var array = require( '@stdlib/ndarray/array' );
117+
var dtype = require( '@stdlib/ndarray/dtype' );
118+
var ndarray2array = require( '@stdlib/ndarray/to-array' );
119+
120+
var x = array( [ [ [ 1.0, 2.0 ] ], [ [ 3.0, 4.0 ] ], [ [ 5.0, 6.0 ] ] ] );
121+
// returns <ndarray>
122+
123+
var y = flatten( x, {
124+
'dtype': 'float32'
125+
});
126+
// returns <ndarray>
127+
128+
var dt = dtype( y );
129+
// returns 'float32'
130+
131+
var arr = ndarray2array( y );
132+
// returns [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 ]
133+
```
134+
111135
</section>
112136

113137
<!-- /.usage -->
@@ -164,6 +188,8 @@ console.log( ndarray2array( y ) );
164188

165189
[@stdlib/ndarray/ctor]: https://github.com/stdlib-js/stdlib/tree/develop/lib/node_modules/%40stdlib/ndarray/ctor
166190

191+
[@stdlib/ndarray/dtypes]: https://github.com/stdlib-js/stdlib/tree/develop/lib/node_modules/%40stdlib/ndarray/dtypes
192+
167193
[@stdlib/ndarray/orders]: https://github.com/stdlib-js/stdlib/tree/develop/lib/node_modules/%40stdlib/ndarray/orders
168194

169195
<!-- <related-links> -->

lib/node_modules/@stdlib/ndarray/flatten/docs/repl.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@
2929

3030
Default: 'row-major'.
3131

32+
options.dtype: string (optional)
33+
Output ndarray data type. By default, the function returns an ndarray
34+
having the same data type as the provided input ndarray.
35+
3236
Returns
3337
-------
3438
out: ndarray

lib/node_modules/@stdlib/ndarray/flatten/docs/types/index.d.ts

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
/// <reference types="@stdlib/types"/>
2222

23-
import { ndarray, Order } from '@stdlib/types/ndarray';
23+
import { ndarray, Order, DataType } from '@stdlib/types/ndarray';
2424

2525
/**
2626
* Interface defining function options.
@@ -50,6 +50,15 @@ interface Options {
5050
* - Default: 'row-major'.
5151
*/
5252
order?: Order | 'same' | 'any';
53+
54+
/**
55+
* Output ndarray data type.
56+
*
57+
* ## Notes
58+
*
59+
* - By default, the function returns an ndarray having the same data type as a provided input ndarray.
60+
*/
61+
dtype?: DataType;
5362
}
5463

5564
/**

lib/node_modules/@stdlib/ndarray/flatten/docs/types/test.ts

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,30 @@ import flatten = require( './index' );
128128
flatten( zeros( 'generic', [ 2, 2, 2 ], 'row-major' ), { 'order': ( x: number ): number => x } ); // $ExpectError
129129
}
130130

131+
// The compiler throws an error if the function is provided a second argument with invalid `dtype` option...
132+
{
133+
flatten( zeros( 'float64', [ 2, 2, 2 ], 'row-major' ), { 'dtype': '5' } ); // $ExpectError
134+
flatten( zeros( 'float64', [ 2, 2, 2 ], 'row-major' ), { 'dtype': true } ); // $ExpectError
135+
flatten( zeros( 'float64', [ 2, 2, 2 ], 'row-major' ), { 'dtype': false } ); // $ExpectError
136+
flatten( zeros( 'float64', [ 2, 2, 2 ], 'row-major' ), { 'dtype': null } ); // $ExpectError
137+
flatten( zeros( 'float64', [ 2, 2, 2 ], 'row-major' ), { 'dtype': [ 1 ] } ); // $ExpectError
138+
flatten( zeros( 'float64', [ 2, 2, 2 ], 'row-major' ), { 'dtype': ( x: number ): number => x } ); // $ExpectError
139+
140+
flatten( zeros( 'complex128', [ 2, 2, 2 ], 'row-major' ), { 'dtype': '5' } ); // $ExpectError
141+
flatten( zeros( 'complex128', [ 2, 2, 2 ], 'row-major' ), { 'dtype': true } ); // $ExpectError
142+
flatten( zeros( 'complex128', [ 2, 2, 2 ], 'row-major' ), { 'dtype': false } ); // $ExpectError
143+
flatten( zeros( 'complex128', [ 2, 2, 2 ], 'row-major' ), { 'dtype': null } ); // $ExpectError
144+
flatten( zeros( 'complex128', [ 2, 2, 2 ], 'row-major' ), { 'dtype': [ 1 ] } ); // $ExpectError
145+
flatten( zeros( 'complex128', [ 2, 2, 2 ], 'row-major' ), { 'dtype': ( x: number ): number => x } ); // $ExpectError
146+
147+
flatten( zeros( 'generic', [ 2, 2, 2 ], 'row-major' ), { 'dtype': '5' } ); // $ExpectError
148+
flatten( zeros( 'generic', [ 2, 2, 2 ], 'row-major' ), { 'dtype': true } ); // $ExpectError
149+
flatten( zeros( 'generic', [ 2, 2, 2 ], 'row-major' ), { 'dtype': false } ); // $ExpectError
150+
flatten( zeros( 'generic', [ 2, 2, 2 ], 'row-major' ), { 'dtype': null } ); // $ExpectError
151+
flatten( zeros( 'generic', [ 2, 2, 2 ], 'row-major' ), { 'dtype': [ 1 ] } ); // $ExpectError
152+
flatten( zeros( 'generic', [ 2, 2, 2 ], 'row-major' ), { 'dtype': ( x: number ): number => x } ); // $ExpectError
153+
}
154+
131155
// The compiler throws an error if the function is provided an unsupported number of arguments...
132156
{
133157
const x = zeros( 'float64', [ 2, 2, 2 ], 'row-major' );

lib/node_modules/@stdlib/ndarray/flatten/lib/main.js

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ var COL_MAJOR = 'column-major';
5454
* @param {Options} [options] - function options
5555
* @param {NonNegativeInteger} [options.depth] - maximum number of dimensions to flatten
5656
* @param {string} [options.order='row-major'] - order in which input ndarray elements should be flattened
57+
* @param {*} [options.dtype] - output ndarray data type
5758
* @throws {TypeError} first argument must be an ndarray-like object
5859
* @throws {TypeError} options argument must be an object
5960
* @throws {TypeError} must provide valid options
@@ -296,8 +297,9 @@ function flatten( x, options ) {
296297

297298
// Define default options:
298299
opts = {
299-
'depth': xsh.length, // by default, flatten to a one-dimensional ndarray
300-
'order': ROW_MAJOR // by default, flatten in lexicographic order (i.e., trailing dimensions first; e.g., if `x` is a matrix, flatten row-by-row)
300+
'depth': xsh.length, // by default, flatten to a one-dimensional ndarray
301+
'order': ROW_MAJOR, // by default, flatten in lexicographic order (i.e., trailing dimensions first; e.g., if `x` is a matrix, flatten row-by-row)
302+
'dtype': getDType( x )
301303
};
302304

303305
// Resolve function options...
@@ -335,16 +337,21 @@ function flatten( x, options ) {
335337
throw new TypeError( format( 'invalid option. `%s` option must be a recognized order. Option: `%s`.', 'order', options.order ) );
336338
}
337339
}
340+
if ( hasOwnProp( options, 'dtype' ) ) {
341+
// Delegate `dtype` validation to `emptyLike` during output array creation:
342+
opts.dtype = options.dtype;
343+
}
338344
}
339345
// Create an output ndarray having contiguous memory:
340346
y = emptyLike( x, {
341347
'shape': flattenShape( xsh, opts.depth ),
342-
'order': opts.order
348+
'order': opts.order,
349+
'dtype': opts.dtype
343350
});
344351

345352
// Create a view on top of output ndarray having the same shape as the input ndarray:
346353
st = ( xsh.length > 0 ) ? shape2strides( xsh, opts.order ) : [ 0 ];
347-
view = ndarray( getDType( y ), getData( y ), xsh, st, 0, opts.order );
354+
view = ndarray( opts.dtype, getData( y ), xsh, st, 0, opts.order );
348355

349356
// Copy elements to the output ndarray:
350357
assign( [ x, view ] );

lib/node_modules/@stdlib/ndarray/flatten/test/test.js

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,39 @@ tape( 'the function throws an error if provided an invalid `order` option', func
191191
}
192192
});
193193

194+
tape( 'the function throws an error if provided an invalid `dtype` option', function test( t ) {
195+
var values;
196+
var i;
197+
198+
values = [
199+
'foo',
200+
'bar',
201+
1,
202+
NaN,
203+
true,
204+
false,
205+
void 0,
206+
null,
207+
[],
208+
{},
209+
function noop() {}
210+
];
211+
212+
for ( i = 0; i < values.length; i++ ) {
213+
t.throws( badValue( values[i] ), TypeError, 'throws an error when provided '+ values[i] );
214+
}
215+
t.end();
216+
217+
function badValue( value ) {
218+
return function badValue() {
219+
var opts = {
220+
'dtype': value
221+
};
222+
flatten( zeros( [ 2 ] ), opts );
223+
};
224+
}
225+
});
226+
194227
tape( 'by default, the function flattens all dimensions of a provided input ndarray in lexicographic order (row-major, contiguous)', function test( t ) {
195228
var expected;
196229
var xbuf;

0 commit comments

Comments
 (0)