Skip to content

Commit d5e1c28

Browse files
committed
fix: maintain floating-point precision
--- type: pre_commit_static_analysis_report description: Results of running static analysis checks when committing changes. report: - task: lint_filenames status: passed - task: lint_editorconfig status: passed - task: lint_markdown status: na - task: lint_package_json status: na - task: lint_repl_help status: na - task: lint_javascript_src status: passed - task: lint_javascript_cli status: na - task: lint_javascript_examples status: na - task: lint_javascript_tests status: passed - task: lint_javascript_benchmarks status: na - task: lint_python status: na - task: lint_r status: na - task: lint_c_src status: na - task: lint_c_examples status: na - task: lint_c_benchmarks status: na - task: lint_c_tests_fixtures status: na - task: lint_shell status: na - task: lint_typescript_declarations status: na - task: lint_typescript_tests status: na - task: lint_license_headers status: passed ---
1 parent 86edcb5 commit d5e1c28

File tree

3 files changed

+53
-12
lines changed

3 files changed

+53
-12
lines changed

lib/node_modules/@stdlib/ndarray/base/output-dtype/lib/main.js

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ var isIntegerIndexDataType = require( '@stdlib/ndarray/base/assert/is-integer-in
3535
var isBooleanIndexDataType = require( '@stdlib/ndarray/base/assert/is-boolean-index-data-type' );
3636
var isMaskIndexDataType = require( '@stdlib/ndarray/base/assert/is-mask-index-data-type' );
3737
var isDataType = require( '@stdlib/ndarray/base/assert/is-data-type' );
38+
var isString = require( '@stdlib/assert/is-string' ).isPrimitive;
3839
var promoteDataTypes = require( '@stdlib/ndarray/base/promote-dtypes' );
3940
var defaults = require( '@stdlib/ndarray/defaults' );
4041
var join = require( '@stdlib/array/base/join' );
@@ -48,6 +49,7 @@ var DEFAULT_INDEX_DTYPE = defaults.get( 'dtypes.default_index' );
4849
var DEFAULT_SIGNED_INTEGER_DTYPE = defaults.get( 'dtypes.signed_integer' );
4950
var DEFAULT_UNSIGNED_INTEGER_DTYPE = defaults.get( 'dtypes.unsigned_integer' );
5051
var DEFAULT_REAL_FLOATING_POINT_DTYPE = defaults.get( 'dtypes.real_floating_point' );
52+
var DEFAULT_COMPLEX_FLOATING_POINT_DTYPE = defaults.get( 'dtypes.complex_floating_point' );
5153

5254
// Table where, for each respective policy, the value is a function which applies the policy to an input data type...
5355
var POLICY_TABLE1 = {
@@ -71,19 +73,19 @@ var POLICY_TABLE2 = {
7173
],
7274
'real_floating_point': [
7375
isRealFloatingPointDataType,
74-
DEFAULT_REAL_FLOATING_POINT_DTYPE
76+
resolveDefaultRealFloatingPoint
7577
],
7678
'real_floating_point_and_generic': [
7779
wrap( isRealFloatingPointDataType ),
78-
DEFAULT_REAL_FLOATING_POINT_DTYPE
80+
resolveDefaultRealFloatingPoint
7981
],
8082
'complex_floating_point': [
8183
isComplexFloatingPointDataType,
82-
defaults.get( 'dtypes.complex_floating_point' )
84+
resolveDefaultComplexFloatingPoint
8385
],
8486
'complex_floating_point_and_generic': [
8587
wrap( isComplexFloatingPointDataType ),
86-
defaults.get( 'dtypes.complex_floating_point' )
88+
resolveDefaultComplexFloatingPoint
8789
],
8890

8991
// Integer policies...
@@ -169,6 +171,20 @@ var POLICY_TABLE2 = {
169171
]
170172
};
171173

174+
// Table mapping complex-valued floating-point data types to real-valued floating-point data types having the same precision:
175+
var COMPLEX2FLOAT = {
176+
'complex128': 'float64',
177+
'complex64': 'float32',
178+
'complex32': 'float16'
179+
};
180+
181+
// Table mapping real-valued floating-point data types to complex-valued floating-point data types having the same precision:
182+
var FLOAT2COMPLEX = {
183+
'float64': 'complex128',
184+
'float32': 'complex64',
185+
'float16': 'complex32'
186+
};
187+
172188

173189
// FUNCTIONS //
174190

@@ -280,6 +296,28 @@ function accumulationPolicy( dtypes ) {
280296
return DEFAULT_REAL_FLOATING_POINT_DTYPE;
281297
}
282298

299+
/**
300+
* Resolves a default real-valued floating-point data type which preserves floating-point precision.
301+
*
302+
* @private
303+
* @param {string} dtype - input ndarray data type
304+
* @returns {string} output ndarray data type
305+
*/
306+
function resolveDefaultRealFloatingPoint( dtype ) {
307+
return COMPLEX2FLOAT[ dtype ] || DEFAULT_REAL_FLOATING_POINT_DTYPE;
308+
}
309+
310+
/**
311+
* Resolves a default complex-valued floating-point data type which preserves floating-point precision.
312+
*
313+
* @private
314+
* @param {string} dtype - input ndarray data type
315+
* @returns {string} output ndarray data type
316+
*/
317+
function resolveDefaultComplexFloatingPoint( dtype ) {
318+
return FLOAT2COMPLEX[ dtype ] || DEFAULT_COMPLEX_FLOATING_POINT_DTYPE;
319+
}
320+
283321

284322
// MAIN //
285323

@@ -320,8 +358,11 @@ function resolve( dtypes, policy ) {
320358
// If so, we can just return the promoted data type:
321359
return dt;
322360
}
323-
// Otherwise, we need to fallback to a default data type belonging to that "kind":
324-
return p[ 1 ];
361+
// Otherwise, we need to fallback to a default data type belonging to that "kind"...
362+
if ( isString( p[ 1 ] ) ) {
363+
return p[ 1 ];
364+
}
365+
return p[ 1 ]( dt );
325366
}
326367
throw new TypeError( format( 'invalid argument. Second argument must be a supported data type policy. Value: `%s`.', policy ) );
327368
}

lib/node_modules/@stdlib/ndarray/base/output-dtype/test/test.binary.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -608,7 +608,7 @@ tape( 'the function resolves an output data type (policy=complex_floating_point)
608608
dt = defaults.get( 'dtypes.complex_floating_point' );
609609
expected = [
610610
dt,
611-
dt,
611+
'complex64',
612612
dt,
613613
dt,
614614
dt,
@@ -651,7 +651,7 @@ tape( 'the function resolves an output data type (policy=complex_floating_point_
651651
dt = defaults.get( 'dtypes.complex_floating_point' );
652652
expected = [
653653
dt,
654-
dt,
654+
'complex64',
655655
dt,
656656
dt,
657657
'generic',

lib/node_modules/@stdlib/ndarray/base/output-dtype/test/test.unary.js

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ tape( 'the function resolves an output data type (policy=real_floating_point)',
352352
dt,
353353
dt,
354354
dt,
355-
dt
355+
'float32'
356356
];
357357
for ( i = 0; i < values.length; i++ ) {
358358
dt = resolve( [ values[ i ] ], 'real_floating_point' );
@@ -385,7 +385,7 @@ tape( 'the function resolves an output data type (policy=real_floating_point_and
385385
dt,
386386
'generic',
387387
dt,
388-
dt
388+
'float32'
389389
];
390390
for ( i = 0; i < values.length; i++ ) {
391391
dt = resolve( [ values[ i ] ], 'real_floating_point_and_generic' );
@@ -413,7 +413,7 @@ tape( 'the function resolves an output data type (policy=complex_floating_point)
413413
dt = defaults.get( 'dtypes.complex_floating_point' );
414414
expected = [
415415
dt,
416-
dt,
416+
'complex64',
417417
dt,
418418
dt,
419419
dt,
@@ -446,7 +446,7 @@ tape( 'the function resolves an output data type (policy=complex_floating_point_
446446
dt = defaults.get( 'dtypes.complex_floating_point' );
447447
expected = [
448448
dt,
449-
dt,
449+
'complex64',
450450
dt,
451451
dt,
452452
'generic',

0 commit comments

Comments
 (0)