You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
@@ -252,6 +255,9 @@ setReadOnly( UnaryStrided1dDispatch.prototype, 'apply', function apply( x ) {
252
255
varshx;
253
256
varshy;
254
257
vararr;
258
+
vartmp;
259
+
varxdt;
260
+
varydt;
255
261
vardt;
256
262
varf;
257
263
varN;
@@ -262,9 +268,9 @@ setReadOnly( UnaryStrided1dDispatch.prototype, 'apply', function apply( x ) {
262
268
if(!isndarrayLike(x)){
263
269
thrownewTypeError(format('invalid argument. First argument must be an ndarray-like object. Value: `%s`.',x));
264
270
}
265
-
dt=getDType(x);
266
-
if(!contains(this._idtypes[0],dt)){
267
-
thrownewTypeError(format('invalid argument. First argument must have one of the following data types: "%s". Data type: `%s`.',join(this._idtypes[0],'", "'),dt));
271
+
xdt=getDType(x);
272
+
if(!contains(this._idtypes[0],xdt)){
273
+
thrownewTypeError(format('invalid argument. First argument must have one of the following data types: "%s". Data type: `%s`.',join(this._idtypes[0],'", "'),xdt));
268
274
}
269
275
args=[x];
270
276
for(i=1;i<nargs;i++){
@@ -276,6 +282,7 @@ setReadOnly( UnaryStrided1dDispatch.prototype, 'apply', function apply( x ) {
276
282
if(!contains(this._idtypes[i],dt)){
277
283
thrownewTypeError(format('invalid argument. Argument %d must have one of the following data types: "%s". Data type: `%s`.',i,join(this._idtypes[i],'", "'),dt));
278
284
}
285
+
// Note: we don't type promote additional ndarray arguments, as they are passed as scalars to the underlying strided reduction function...
279
286
args.push(arr);
280
287
}
281
288
// If we didn't make it up until the last argument, this means that we found a non-options argument which was not an ndarray...
@@ -293,7 +300,7 @@ setReadOnly( UnaryStrided1dDispatch.prototype, 'apply', function apply( x ) {
293
300
throwerr;
294
301
}
295
302
}
296
-
// When a list of dimensions is not provided, reduce the entire input array across all dimensions...
303
+
// When a list of dimensions is not provided, reduce the entire input ndarray across all dimensions...
297
304
if(opts.dims===null){
298
305
opts.dims=zeroTo(N);
299
306
}
@@ -304,13 +311,24 @@ setReadOnly( UnaryStrided1dDispatch.prototype, 'apply', function apply( x ) {
304
311
shy=takeIndexed(shx,idx);
305
312
306
313
// Initialize an output array whose shape matches that of the non-reduced dimensions and which has the same memory layout as the input array:
// When performing an accumulation, such as a sum over many `int8` elements, we need to copy the input ndarray to a temporary workspace prior to performing a reduction whenever the promoted data type has a higher precision with the aim of guarding against overflow/underflow during intermediate computation (note: this follows similar guidance found in https://data-apis.org/array-api/latest/API_specification/generated/array_api.sum.html)...
321
+
if(xdt!==ydt&&this._policy==='accumulation'){
322
+
dt=promotionRules(xdt,ydt);
323
+
if(dt!==-1&&xdt!==dt){// note: only perform the cast when an input data type promotes to an output data type; this can lead to divergence between, e.g., uint32-complex128 and uint32-complex64, where the former promotes, but the latter stays in uint32; however, we only get there if a user has specifically requested an output data type and who are we to question the user :|
324
+
tmp=baseEmpty(dt,shx,getOrder(x));
325
+
assign([x,tmp]);
326
+
args[0]=tmp;
327
+
xdt=dt;
328
+
}
329
+
}
312
330
// Resolve the lower-level strided function satisfying the input ndarray data type:
@@ -385,6 +403,9 @@ setReadOnly( UnaryStrided1dDispatch.prototype, 'assign', function assign( x ) {
385
403
vararr;
386
404
varerr;
387
405
varflg;
406
+
varxdt;
407
+
varydt;
408
+
vartmp;
388
409
vardt;
389
410
varN;
390
411
varf;
@@ -396,9 +417,9 @@ setReadOnly( UnaryStrided1dDispatch.prototype, 'assign', function assign( x ) {
396
417
thrownewTypeError(format('invalid argument. First argument must be an ndarray-like object. Value: `%s`.',x));
397
418
}
398
419
// Validate the input ndarray data type in order to maintain similar behavior to `apply` above...
399
-
dt=getDType(x);
400
-
if(!contains(this._idtypes[0],dt)){
401
-
thrownewTypeError(format('invalid argument. First argument must have one of the following data types: "%s". Data type: `%s`.',join(this._idtypes[0],'", "'),dt));
420
+
xdt=getDType(x);
421
+
if(!contains(this._idtypes[0],xdt)){
422
+
thrownewTypeError(format('invalid argument. First argument must have one of the following data types: "%s". Data type: `%s`.',join(this._idtypes[0],'", "'),xdt));
402
423
}
403
424
args=[x];
404
425
@@ -426,7 +447,7 @@ setReadOnly( UnaryStrided1dDispatch.prototype, 'assign', function assign( x ) {
426
447
// Cache a reference to the output ndarray:
427
448
y=args.pop();
428
449
429
-
// Verify that additional ndarray arguments have expected dtypes (note: we intentionally don't validate the output ndarray dtype in order to provide an escape hatch for a user wanting to have an output ndarray having a specific dtype that `apply` does not support)...
450
+
// Verify that additional ndarray arguments have expected dtypes (note: we intentionally don't validate the output ndarray dtype in order to provide an escape hatch for a user wanting to have an output ndarray having a specific dtype that `apply` does not support; note: we don't type promote additional ndarray arguments, as they are passed as scalars to the underlying strided reduction function)...
430
451
for(i=1;i<args.length;i++){
431
452
dt=getDType(args[i]);
432
453
if(!contains(this._idtypes[i],dt)){
@@ -446,8 +467,19 @@ setReadOnly( UnaryStrided1dDispatch.prototype, 'assign', function assign( x ) {
446
467
if(opts.dims===null){
447
468
opts.dims=zeroTo(N);
448
469
}
470
+
// When performing an accumulation, such as a sum over many `int8` elements, we need to copy the input ndarray to a temporary workspace prior to performing a reduction whenever the promoted data type has a higher precision with the aim of guarding against overflow/underflow during intermediate computation (note: this follows similar guidance found in https://data-apis.org/array-api/latest/API_specification/generated/array_api.sum.html)...
471
+
ydt=getDType(y);
472
+
if(xdt!==ydt&&this._policy==='accumulation'){
473
+
dt=promotionRules(xdt,ydt);
474
+
if(dt!==-1&&xdt!==dt){// note: only perform the cast when an input data type promotes to an output data type; this can lead to divergence between, e.g., uint32-complex128 and uint32-complex64, where the former promotes, but the latter stays in uint32; however, we only get there if a user has specifically provided an output array with a data type which doesn't promote and who are we to question the user :|
475
+
tmp=baseEmpty(dt,getShape(x),getOrder(x));
476
+
assign([x,tmp]);
477
+
args[0]=tmp;
478
+
xdt=dt;
479
+
}
480
+
}
449
481
// Resolve the lower-level strided function satisfying the input ndarray data type:
0 commit comments