Skip to content

Commit 38de753

Browse files
committed
feat: add logic supporting an accumulation policy
--- type: pre_commit_static_analysis_report description: Results of running static analysis checks when committing changes. report: - task: lint_filenames status: passed - task: lint_editorconfig status: passed - task: lint_markdown status: na - task: lint_package_json status: na - task: lint_repl_help status: na - task: lint_javascript_src status: passed - task: lint_javascript_cli status: na - task: lint_javascript_examples status: na - task: lint_javascript_tests status: na - task: lint_javascript_benchmarks status: na - task: lint_python status: na - task: lint_r status: na - task: lint_c_src status: na - task: lint_c_examples status: na - task: lint_c_benchmarks status: na - task: lint_c_tests_fixtures status: na - task: lint_shell status: na - task: lint_typescript_declarations status: na - task: lint_typescript_tests status: na - task: lint_license_headers status: passed ---
1 parent 2d63924 commit 38de753

File tree

1 file changed

+43
-11
lines changed
  • lib/node_modules/@stdlib/ndarray/base/unary-reduce-strided1d-dispatch/lib

1 file changed

+43
-11
lines changed

lib/node_modules/@stdlib/ndarray/base/unary-reduce-strided1d-dispatch/lib/main.js

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,10 @@ var getShape = require( '@stdlib/ndarray/shape' ); // note: non-base accessor is
4141
var ndims = require( '@stdlib/ndarray/ndims' );
4242
var getDType = require( '@stdlib/ndarray/base/dtype' );
4343
var getOrder = require( '@stdlib/ndarray/base/order' );
44+
var assign = require( '@stdlib/ndarray/base/assign' );
45+
var baseEmpty = require( '@stdlib/ndarray/base/empty' );
4446
var empty = require( '@stdlib/ndarray/empty' );
47+
var promotionRules = require( '@stdlib/ndarray/promotion-rules' );
4548
var indicesComplement = require( '@stdlib/array/base/indices-complement' );
4649
var takeIndexed = require( '@stdlib/array/base/take-indexed' );
4750
var zeroTo = require( '@stdlib/array/base/zero-to' );
@@ -252,6 +255,9 @@ setReadOnly( UnaryStrided1dDispatch.prototype, 'apply', function apply( x ) {
252255
var shx;
253256
var shy;
254257
var arr;
258+
var tmp;
259+
var xdt;
260+
var ydt;
255261
var dt;
256262
var f;
257263
var N;
@@ -262,9 +268,9 @@ setReadOnly( UnaryStrided1dDispatch.prototype, 'apply', function apply( x ) {
262268
if ( !isndarrayLike( x ) ) {
263269
throw new TypeError( format( 'invalid argument. First argument must be an ndarray-like object. Value: `%s`.', x ) );
264270
}
265-
dt = getDType( x );
266-
if ( !contains( this._idtypes[ 0 ], dt ) ) {
267-
throw new TypeError( format( 'invalid argument. First argument must have one of the following data types: "%s". Data type: `%s`.', join( this._idtypes[ 0 ], '", "' ), dt ) );
271+
xdt = getDType( x );
272+
if ( !contains( this._idtypes[ 0 ], xdt ) ) {
273+
throw new TypeError( format( 'invalid argument. First argument must have one of the following data types: "%s". Data type: `%s`.', join( this._idtypes[ 0 ], '", "' ), xdt ) );
268274
}
269275
args = [ x ];
270276
for ( i = 1; i < nargs; i++ ) {
@@ -276,6 +282,7 @@ setReadOnly( UnaryStrided1dDispatch.prototype, 'apply', function apply( x ) {
276282
if ( !contains( this._idtypes[ i ], dt ) ) {
277283
throw new TypeError( format( 'invalid argument. Argument %d must have one of the following data types: "%s". Data type: `%s`.', i, join( this._idtypes[ i ], '", "' ), dt ) );
278284
}
285+
// Note: we don't type promote additional ndarray arguments, as they are passed as scalars to the underlying strided reduction function...
279286
args.push( arr );
280287
}
281288
// If we didn't make it up until the last argument, this means that we found a non-options argument which was not an ndarray...
@@ -293,7 +300,7 @@ setReadOnly( UnaryStrided1dDispatch.prototype, 'apply', function apply( x ) {
293300
throw err;
294301
}
295302
}
296-
// When a list of dimensions is not provided, reduce the entire input array across all dimensions...
303+
// When a list of dimensions is not provided, reduce the entire input ndarray across all dimensions...
297304
if ( opts.dims === null ) {
298305
opts.dims = zeroTo( N );
299306
}
@@ -304,13 +311,24 @@ setReadOnly( UnaryStrided1dDispatch.prototype, 'apply', function apply( x ) {
304311
shy = takeIndexed( shx, idx );
305312

306313
// Initialize an output array whose shape matches that of the non-reduced dimensions and which has the same memory layout as the input array:
314+
ydt = opts.dtype || unaryOutputDataType( xdt, this._policy );
307315
y = empty( shy, {
308-
'dtype': opts.dtype || unaryOutputDataType( dt, this._policy ),
316+
'dtype': ydt,
309317
'order': getOrder( x )
310318
});
311319

320+
// When performing an accumulation, such as a sum over many `int8` elements, we need to copy the input ndarray to a temporary workspace prior to performing a reduction whenever the promoted data type has a higher precision with the aim of guarding against overflow/underflow during intermediate computation (note: this follows similar guidance found in https://data-apis.org/array-api/latest/API_specification/generated/array_api.sum.html)...
321+
if ( xdt !== ydt && this._policy === 'accumulation' ) {
322+
dt = promotionRules( xdt, ydt );
323+
if ( dt !== -1 && xdt !== dt ) { // note: only perform the cast when an input data type promotes to an output data type; this can lead to divergence between, e.g., uint32-complex128 and uint32-complex64, where the former promotes, but the latter stays in uint32; however, we only get there if a user has specifically requested an output data type and who are we to question the user :|
324+
tmp = baseEmpty( dt, shx, getOrder( x ) );
325+
assign( [ x, tmp ] );
326+
args[ 0 ] = tmp;
327+
xdt = dt;
328+
}
329+
}
312330
// Resolve the lower-level strided function satisfying the input ndarray data type:
313-
dtypes = [ resolveEnum( dt ) ];
331+
dtypes = [ resolveEnum( xdt ) ];
314332
i = indexOfTypes( this._table.fcns.length, 1, this._table.types, 1, 1, 0, dtypes, 1, 0 ); // eslint-disable-line max-len
315333
if ( i >= 0 ) {
316334
f = this._table.fcns[ i ];
@@ -385,6 +403,9 @@ setReadOnly( UnaryStrided1dDispatch.prototype, 'assign', function assign( x ) {
385403
var arr;
386404
var err;
387405
var flg;
406+
var xdt;
407+
var ydt;
408+
var tmp;
388409
var dt;
389410
var N;
390411
var f;
@@ -396,9 +417,9 @@ setReadOnly( UnaryStrided1dDispatch.prototype, 'assign', function assign( x ) {
396417
throw new TypeError( format( 'invalid argument. First argument must be an ndarray-like object. Value: `%s`.', x ) );
397418
}
398419
// Validate the input ndarray data type in order to maintain similar behavior to `apply` above...
399-
dt = getDType( x );
400-
if ( !contains( this._idtypes[ 0 ], dt ) ) {
401-
throw new TypeError( format( 'invalid argument. First argument must have one of the following data types: "%s". Data type: `%s`.', join( this._idtypes[ 0 ], '", "' ), dt ) );
420+
xdt = getDType( x );
421+
if ( !contains( this._idtypes[ 0 ], xdt ) ) {
422+
throw new TypeError( format( 'invalid argument. First argument must have one of the following data types: "%s". Data type: `%s`.', join( this._idtypes[ 0 ], '", "' ), xdt ) );
402423
}
403424
args = [ x ];
404425

@@ -426,7 +447,7 @@ setReadOnly( UnaryStrided1dDispatch.prototype, 'assign', function assign( x ) {
426447
// Cache a reference to the output ndarray:
427448
y = args.pop();
428449

429-
// Verify that additional ndarray arguments have expected dtypes (note: we intentionally don't validate the output ndarray dtype in order to provide an escape hatch for a user wanting to have an output ndarray having a specific dtype that `apply` does not support)...
450+
// Verify that additional ndarray arguments have expected dtypes (note: we intentionally don't validate the output ndarray dtype in order to provide an escape hatch for a user wanting to have an output ndarray having a specific dtype that `apply` does not support; note: we don't type promote additional ndarray arguments, as they are passed as scalars to the underlying strided reduction function)...
430451
for ( i = 1; i < args.length; i++ ) {
431452
dt = getDType( args[ i ] );
432453
if ( !contains( this._idtypes[ i ], dt ) ) {
@@ -446,8 +467,19 @@ setReadOnly( UnaryStrided1dDispatch.prototype, 'assign', function assign( x ) {
446467
if ( opts.dims === null ) {
447468
opts.dims = zeroTo( N );
448469
}
470+
// When performing an accumulation, such as a sum over many `int8` elements, we need to copy the input ndarray to a temporary workspace prior to performing a reduction whenever the promoted data type has a higher precision with the aim of guarding against overflow/underflow during intermediate computation (note: this follows similar guidance found in https://data-apis.org/array-api/latest/API_specification/generated/array_api.sum.html)...
471+
ydt = getDType( y );
472+
if ( xdt !== ydt && this._policy === 'accumulation' ) {
473+
dt = promotionRules( xdt, ydt );
474+
if ( dt !== -1 && xdt !== dt ) { // note: only perform the cast when an input data type promotes to an output data type; this can lead to divergence between, e.g., uint32-complex128 and uint32-complex64, where the former promotes, but the latter stays in uint32; however, we only get there if a user has specifically provided an output array with a data type which doesn't promote and who are we to question the user :|
475+
tmp = baseEmpty( dt, getShape( x ), getOrder( x ) );
476+
assign( [ x, tmp ] );
477+
args[ 0 ] = tmp;
478+
xdt = dt;
479+
}
480+
}
449481
// Resolve the lower-level strided function satisfying the input ndarray data type:
450-
dtypes = [ resolveEnum( dt ) ];
482+
dtypes = [ resolveEnum( xdt ) ];
451483
i = indexOfTypes( this._table.fcns.length, 1, this._table.types, 1, 1, 0, dtypes, 1, 0 ); // eslint-disable-line max-len
452484
if ( i >= 0 ) {
453485
f = this._table.fcns[ i ];

0 commit comments

Comments
 (0)